"""DI-engine command-line entry point: dispatches serial/parallel/dist training and evaluation pipelines."""
from typing import List, Union
import os
import copy
import click
from click.core import Context, Option
import numpy as np

from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.config import read_config
from .predefined_config import get_predefined_config


def print_version(ctx: Context, param: Option, value: bool) -> None:
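    """Eager click callback: print the package title, version and author information, then exit."""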
    if not value or ctx.resilient_parsing:
        return
    click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
    click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
    ctx.exit()


def print_registry(ctx: Context, param: Option, value: str) -> None:
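    """Eager click callback: list every module registered under the given registry name, then exit."""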
    if value is None:
        return
    from ding.utils import registries
    if value not in registries:
        click.echo('[ERROR]: unsupported registry name: {}'.format(value))
    else:
        registered_info = registries[value].query_details()
        click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
        for alias, info in registered_info.items():
            click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
    ctx.exit()


# Shared click settings: accept both '-h' and '--help' for help output.
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option(
    '-q',
    '--query-registry',
    type=str,
    callback=print_registry,
    expose_value=False,
    is_eager=True,
    help='Query registered modules or functions, showing each name and registration path.'
)
@click.option(
    '-m',
    '--mode',
    type=click.Choice(
        [
            'serial',
            'serial_onpolicy',
            'serial_sqil',
            'serial_dqfd',
            'serial_trex',
            'serial_trex_onpolicy',
            'parallel',
            'dist',
            'eval',
            'serial_reward_model',
            'serial_gail',
            'serial_offline',
            'serial_ngu',
        ]
    ),
    help='Which pipeline to run: serial-train, parallel-train, dist-train or eval.'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@click.option(
    '-s',
    '--seed',
    type=int,
    default=[0],
    multiple=True,
    help='Random generator seed (for all the relevant packages: random, numpy, torch and user env)'
)
@click.option('-e', '--env', type=str, help='RL env name')
@click.option('-p', '--policy', type=str, help='DRL policy name')
@click.option('--exp-name', type=str, help='Experiment directory name')
@click.option('--train-iter', type=str, default='1e8', help='Maximum policy update iterations in training')
@click.option('--env-step', type=str, default='1e8', help='Maximum collected environment steps for training')
@click.option('--load-path', type=str, default=None, help='Path to load ckpt')
@click.option('--replay-path', type=str, default=None, help='Path to save replay')
@click.option('--enable-total-log', type=bool, help='Whether to enable the total DI-engine system log', default=False)
@click.option('--disable-flask-log', type=bool, help='Whether to disable the Flask log', default=True)
@click.option(
    '-P', '--platform', type=click.Choice(['local', 'slurm', 'k8s']), help='local or slurm or k8s', default='local'
)
@click.option(
    '-M',
    '--module',
    type=click.Choice(['config', 'collector', 'learner', 'coordinator', 'learner_aggregator', 'spawn_learner']),
    help='dist module type'
)
@click.option('--module-name', type=str, help='dist module name')
@click.option('-cdh', '--coordinator-host', type=str, help='coordinator host', default='0.0.0.0')
@click.option('-cdp', '--coordinator-port', type=int, help='coordinator port')
@click.option('-lh', '--learner-host', type=str, help='learner host', default='0.0.0.0')
@click.option('-lp', '--learner-port', type=int, help='learner port')
@click.option('-clh', '--collector-host', type=str, help='collector host', default='0.0.0.0')
@click.option('-clp', '--collector-port', type=int, help='collector port')
@click.option('-agh', '--aggregator-host', type=str, help='aggregator slave host', default='0.0.0.0')
@click.option('-agp', '--aggregator-port', type=int, help='aggregator slave port')
@click.option('--add', type=click.Choice(['collector', 'learner']), help='add replicas type')
@click.option('--delete', type=click.Choice(['collector', 'learner']), help='delete replicas type')
@click.option('--restart', type=click.Choice(['collector', 'learner']), help='restart replicas type')
@click.option('--kubeconfig', type=str, default=None, help='Path of the Kubernetes configuration file')
@click.option('-cdn', '--coordinator-name', type=str, default=None, help='coordinator name')
@click.option('-ns', '--namespace', type=str, default=None, help='job namespace')
@click.option('-rs', '--replicas', type=int, default=1, help='number of replicas to add/delete/restart')
@click.option('-rpn', '--restart-pod-name', type=str, default=None, help='restart pod name')
@click.option('--cpus', type=int, default=0, help='The requested CPUs; read the value from the DIJob yaml by default')
@click.option('--gpus', type=int, default=0, help='The requested GPUs; read the value from the DIJob yaml by default')
@click.option(
    '--memory', type=str, default=None, help='The requested memory; read the value from the DIJob yaml by default'
)
@click.option(
    '--profile',
    type=str,
    default=None,
    help='Profile time cost with cProfile and save the result files into the specified folder path'
)
def cli(
    # serial training / evaluation options
    mode: str,
    config: str,
    seed: Union[int, List],
    exp_name: str,
    env: str,
    policy: str,
    train_iter: str,
    env_step: str,
    load_path: str,
    replay_path: str,
    # parallel / dist options
    platform: str,
    coordinator_host: str,
    coordinator_port: int,
    learner_host: str,
    learner_port: int,
    collector_host: str,
    collector_port: int,
    aggregator_host: str,
    aggregator_port: int,
    enable_total_log: bool,
    disable_flask_log: bool,
    module: str,
    module_name: str,
    # k8s replica operation options
    add: str,
    delete: str,
    restart: str,
    kubeconfig: str,
    coordinator_name: str,
    namespace: str,
    replicas: int,
    cpus: int,
    gpus: int,
    memory: str,
    restart_pod_name: str,
    profile: str,
):
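    """Run the DI-engine pipeline selected by ``-m/--mode`` with the given config and seed(s).

    Example invocations (a sketch: ``ding`` is assumed to be the installed console command and
    ``cartpole_dqn_config.py`` a hypothetical config file):

        ding -m serial -c cartpole_dqn_config.py -s 0 -s 1   # serial training with two seeds
        ding -m eval -c cartpole_dqn_config.py -s 0 --load-path ./ckpt_best.pth.tar
        ding -q policy                                       # list modules registered under 'policy'
    """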
    if profile is not None:
        from ..utils.profiler_helper import Profiler
        profiler = Profiler()
        profiler.profile(profile)

    # Limits arrive as strings so that scientific notation such as '1e8' works on the command line.
    train_iter = int(float(train_iter))
    env_step = int(float(env_step))

    def run_single_pipeline(seed, config):
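        """Resolve the config (predefined by env/policy name, or read from file) and run one pipeline for one seed."""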
        if config is None:
            config = get_predefined_config(env, policy)
        else:
            config = read_config(config)
        if exp_name is not None:
            config[0].exp_name = exp_name

        # Pipelines are imported lazily, so selecting one mode does not pay the import cost of the others.
        if mode == 'serial':
            from .serial_entry import serial_pipeline
            serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_onpolicy':
            from .serial_entry_onpolicy import serial_pipeline_onpolicy
            serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_sqil':
            from .serial_entry_sqil import serial_pipeline_sqil
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_reward_model':
            from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
            serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_gail':
            from .serial_entry_gail import serial_pipeline_gail
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_gail(
                config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
            )
        elif mode == 'serial_dqfd':
            from .serial_entry_dqfd import serial_pipeline_dqfd
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
                + "models used in Q-learning for now; however, one should still enter the DQFD config "\
                + "here, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
            serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex':
            from .serial_entry_trex import serial_pipeline_trex
            serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex_onpolicy':
            from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
            serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_offline':
            from .serial_entry_offline import serial_pipeline_offline
            serial_pipeline_offline(config, seed, max_train_iter=train_iter)
        elif mode == 'serial_ngu':
            from .serial_entry_ngu import serial_pipeline_ngu
            serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
        elif mode == 'parallel':
            from .parallel_entry import parallel_pipeline
            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
        elif mode == 'dist':
            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
            if module == 'config':
                dist_prepare_config(
                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                    learner_port, collector_port
                )
            elif module == 'coordinator':
                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
            elif module == 'learner_aggregator':
                dist_launch_learner_aggregator(
                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
                )
            elif module == 'collector':
                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
            elif module == 'learner':
                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif module == 'spawn_learner':
                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif add in ['collector', 'learner']:
                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
            elif delete in ['collector', 'learner']:
                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
            elif restart in ['collector', 'learner']:
                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
            else:
                raise Exception("invalid dist module or replica operation: {}".format(module))
        elif mode == 'eval':
            from .application_entry import eval
            eval(config, seed, load_path=load_path, replay_path=replay_path)

    if mode is None:
        raise RuntimeError("Please specify a mode with -m/--mode (see --help for the available choices).")

    if isinstance(seed, (list, tuple)):
        assert len(seed) > 0, "Please input at least 1 seed"
        if len(seed) == 1:
            run_single_pipeline(seed[0], config)
        else:
            # With multiple seeds, run each experiment in its own 'seed{N}' subdirectory of a shared
            # root directory, so that the runs do not overwrite each other's outputs.
            if exp_name is None:
                multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
            else:
                multi_exp_root = exp_name
            if not os.path.exists(multi_exp_root):
                os.makedirs(multi_exp_root)
            abs_config_path = os.path.abspath(config)
            origin_root = os.getcwd()
            for s in seed:
                seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
                if not os.path.exists(seed_exp_root):
                    os.makedirs(seed_exp_root)
                os.chdir(seed_exp_root)
                run_single_pipeline(s, abs_config_path)
                os.chdir(origin_root)
    else:
        raise TypeError("invalid seed type: {}".format(type(seed)))
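

# Not part of the original entry-point wiring (the console command is normally installed via the
# package's setuptools entry points); this guard is only a convenience sketch for running the module
# directly, e.g. `python -m ding.entry.cli -m serial -c my_config.py -s 0`.
if __name__ == '__main__':
    cli()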