"""
Overview:
    Helpers that expand a serial-style config into the ``system`` section used by DI-engine's parallel
    (coordinator / learner / collector) mode -- for plain hosts, slurm clusters and k8s -- plus a utility
    that dumps a config back into a python file.
"""
import copy
import math
import numbers
from typing import List, Optional, TextIO, Union

from easydict import EasyDict

from ding.utils import find_free_port, find_free_port_slurm, node_to_partition, node_to_host, pretty_print, \
    DEFAULT_K8S_COLLECTOR_PORT, DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_COORDINATOR_PORT
from dizoo.classic_control.cartpole.config.parallel import cartpole_dqn_config

# Bind-all address used whenever no concrete host is supplied.
default_host = '0.0.0.0'
default_port = 22270


def set_host_port(
        cfg: EasyDict, coordinator_host: str, learner_host: Union[str, List[str]],
        collector_host: Union[str, List[str]]
) -> EasyDict:
    """
    Overview:
        Replace the ``'auto'`` host/port placeholders of the coordinator and of every ``learner*`` /
        ``collector*`` entry in a system config with concrete values.
    Arguments:
        - cfg (:obj:`EasyDict`): System config, modified in place.
        - coordinator_host (:obj:`str`): Host of the coordinator.
        - learner_host (:obj:`Union[str, List[str]]`): One host shared by every learner, or a list consumed \
            one element per learner whose host is ``'auto'`` (in key order).
        - collector_host (:obj:`Union[str, List[str]]`): Same as ``learner_host``, for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object.
    Raises:
        - NotImplementedError: If ``cfg`` already contains a ``learner_aggregator*`` entry.
        - TypeError: If a host argument is neither a ``str`` nor a ``list``.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    learner_count = 0
    collector_count = 0
    for k in cfg.keys():
        # Aggregator entries are produced by set_host_port_slurm and are not supported here. The check uses
        # startswith (like set_learner_interaction_for_coordinator) and runs before the generic 'learner'
        # prefix test, so that e.g. 'learner_aggregator0' is not mistaken for a learner.
        if k.startswith('learner_aggregator'):
            raise NotImplementedError("learner aggregator '{}' is not supported by set_host_port".format(k))
        if k.startswith('learner'):
            if cfg[k].host == 'auto':
                if isinstance(learner_host, list):
                    cfg[k].host = learner_host[learner_count]
                    learner_count += 1
                elif isinstance(learner_host, str):
                    cfg[k].host = learner_host
                else:
                    raise TypeError("not support learner_host type: {}".format(learner_host))
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port(cfg[k].host)
            # No aggregator in this mode: the coordinator talks to the learner directly.
            cfg[k].aggregator = False
        if k.startswith('collector'):
            if cfg[k].host == 'auto':
                if isinstance(collector_host, list):
                    cfg[k].host = collector_host[collector_count]
                    collector_count += 1
                elif isinstance(collector_host, str):
                    cfg[k].host = collector_host
                else:
                    raise TypeError("not support collector_host type: {}".format(collector_host))
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port(cfg[k].host)
    return cfg


def set_host_port_slurm(
        cfg: EasyDict, coordinator_host: str, learner_node: Optional[Union[str, List[str]]],
        collector_node: Optional[Union[str, List[str]]]
) -> EasyDict:
    """
    Overview:
        Slurm variant of ``set_host_port``. Each learner / collector is assigned a slurm node (round-robin
        over the given nodes); its partition and host are derived from that node and free ports are picked
        on it. A learner that owns several ports (one per GPU) additionally gets a ``learner_aggregator*``
        entry that fronts it.
    Arguments:
        - cfg (:obj:`EasyDict`): System config, modified in place.
        - coordinator_host (:obj:`str`): Host of the coordinator.
        - learner_node (:obj:`Optional[Union[str, List[str]]]`): Node(s) for learners; ``None`` leaves the \
            learner entries untouched.
        - collector_node (:obj:`Optional[Union[str, List[str]]]`): Node(s) for collectors; ``None`` leaves \
            the collector entries untouched.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, with aggregator entries added where needed.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    if isinstance(learner_node, str):
        learner_node = [learner_node]
    if isinstance(collector_node, str):
        collector_node = [collector_node]
    learner_count, collector_count = 0, 0
    # learner key -> True if it serves several ports and therefore needs an aggregator in front of it
    learner_multi = {}
    for k in cfg.keys():
        if learner_node is not None and k.startswith('learner'):
            node = learner_node[learner_count % len(learner_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            gpu_num = cfg[k].gpu_num
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                if gpu_num == 1:
                    cfg[k].port = find_free_port_slurm(node)
                    learner_multi[k] = False
                else:
                    # one port per GPU
                    cfg[k].port = [find_free_port_slurm(node) for _ in range(gpu_num)]
                    learner_multi[k] = True
            else:
                # Preset port(s): still record the flag so every handled learner ends up with ``aggregator``,
                # which set_learner_interaction_for_coordinator reads unconditionally.
                learner_multi[k] = isinstance(cfg[k].port, (list, tuple))
            learner_count += 1
        if collector_node is not None and k.startswith('collector'):
            node = collector_node[collector_count % len(collector_node)]
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port_slurm(node)
            collector_count += 1
    for k, flag in learner_multi.items():
        if flag:
            host = cfg[k].host
            learner_interaction_cfg = {str(i): [str(i), host, p] for i, p in enumerate(cfg[k].port)}
            # master and slave each get their own free port on the learner's node; the coordinator is
            # later wired to the slave side (see set_learner_interaction_for_coordinator).
            aggregator_cfg = dict(
                master=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                slave=dict(
                    host=host,
                    port=find_free_port_slurm(cfg[k].node),
                ),
                learner=learner_interaction_cfg,
                node=cfg[k].node,
                partition=cfg[k].partition,
            )
            cfg[k].aggregator = True
            # 'learner0' -> 'learner_aggregator0'
            cfg['learner_aggregator' + k[len('learner'):]] = aggregator_cfg
        else:
            cfg[k].aggregator = False
    return cfg


def set_host_port_k8s(cfg: EasyDict, coordinator_port: int, learner_port: int, collector_port: int) -> EasyDict:
    """
    Overview:
        k8s variant of ``set_host_port``. Every learner / collector gets the same fixed port. Template
        ``learner`` / ``collector`` entries (a copy of the first learner / collector, bound to
        ``default_host``) are added to ``cfg``.
    Arguments:
        - cfg (:obj:`EasyDict`): System config, modified in place.
        - coordinator_port (:obj:`int`): Coordinator port, ``None`` for ``DEFAULT_K8S_COORDINATOR_PORT``.
        - learner_port (:obj:`int`): Learner port, ``None`` for ``DEFAULT_K8S_LEARNER_PORT``.
        - collector_port (:obj:`int`): Collector port, ``None`` for ``DEFAULT_K8S_COLLECTOR_PORT``.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object.
    """
    cfg.coordinator.host = default_host
    cfg.coordinator.port = coordinator_port if coordinator_port is not None else DEFAULT_K8S_COORDINATOR_PORT
    base_learner_cfg = None
    base_collector_cfg = None
    if learner_port is None:
        learner_port = DEFAULT_K8S_LEARNER_PORT
    if collector_port is None:
        collector_port = DEFAULT_K8S_COLLECTOR_PORT
    for k in cfg.keys():
        if k.startswith('learner'):
            # create the base learner config from the first learner entry
            if base_learner_cfg is None:
                base_learner_cfg = copy.deepcopy(cfg[k])
                base_learner_cfg.host = default_host
                base_learner_cfg.port = learner_port
            cfg[k].port = learner_port
        elif k.startswith('collector'):
            # create the base collector config from the first collector entry
            if base_collector_cfg is None:
                base_collector_cfg = copy.deepcopy(cfg[k])
                base_collector_cfg.host = default_host
                base_collector_cfg.port = collector_port
            cfg[k].port = collector_port
    cfg['learner'] = base_learner_cfg
    cfg['collector'] = base_collector_cfg
    return cfg


def set_learner_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Fill ``cfg.coordinator.learner`` with ``[name, host, port]`` for every learner. Learners that have an
        aggregator are reached through the aggregator's ``slave`` endpoint instead of directly.
    Arguments:
        - cfg (:obj:`EasyDict`): System config whose learner entries carry an ``aggregator`` flag.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object.
    """
    cfg.coordinator.learner = {}
    for k in cfg.keys():
        if k.startswith('learner') and not k.startswith('learner_aggregator'):
            if cfg[k].aggregator:
                dst_k = 'learner_aggregator' + k[len('learner'):]
                cfg.coordinator.learner[k] = [k, cfg[dst_k].slave.host, cfg[dst_k].slave.port]
            else:
                dst_k = k
                cfg.coordinator.learner[k] = [k, cfg[dst_k].host, cfg[dst_k].port]
    return cfg


def set_collector_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Fill ``cfg.coordinator.collector`` with ``[name, host, port]`` for every ``collector*`` entry.
    Arguments:
        - cfg (:obj:`EasyDict`): System config.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object.
    """
    cfg.coordinator.collector = {}
    for k in cfg.keys():
        if k.startswith('collector'):
            cfg.coordinator.collector[k] = [k, cfg[k].host, cfg[k].port]
    return cfg


def set_system_cfg(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build a fresh system config holding a ``coordinator`` entry plus ``learner{i}`` / ``collector{i}``
        entries (counts read from ``cfg.main.policy``) whose host/port are still ``'auto'``.
    Arguments:
        - cfg (:obj:`EasyDict`): Total config with ``main`` and ``system`` sections.
    Returns:
        - system_cfg (:obj:`EasyDict`): The new system config.
    """
    learner_num = cfg.main.policy.learn.learner.learner_num
    collector_num = cfg.main.policy.collect.collector.collector_num
    path_data = cfg.system.path_data
    path_policy = cfg.system.path_policy
    coordinator_cfg = cfg.system.coordinator
    communication_mode = cfg.system.communication_mode
    assert communication_mode in ['auto'], communication_mode
    learner_gpu_num = cfg.system.learner_gpu_num
    learner_multi_gpu = learner_gpu_num > 1
    new_cfg = dict(coordinator=dict(
        host='auto',
        port='auto',
    ))
    # user-provided coordinator settings override the 'auto' placeholders
    new_cfg['coordinator'].update(coordinator_cfg)
    for i in range(learner_num):
        new_cfg[f'learner{i}'] = dict(
            type=cfg.system.comm_learner.type,
            import_names=cfg.system.comm_learner.import_names,
            host='auto',
            port='auto',
            path_data=path_data,
            path_policy=path_policy,
            multi_gpu=learner_multi_gpu,
            gpu_num=learner_gpu_num,
        )
    for i in range(collector_num):
        new_cfg[f'collector{i}'] = dict(
            type=cfg.system.comm_collector.type,
            import_names=cfg.system.comm_collector.import_names,
            host='auto',
            port='auto',
            path_data=path_data,
            path_policy=path_policy,
        )
    return EasyDict(new_cfg)


def parallel_transform(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_host: Optional[Union[str, List[str]]] = None,
        collector_host: Optional[Union[str, List[str]]] = None
) -> EasyDict:
    """
    Overview:
        Turn ``cfg`` into a parallel-mode config for plain hosts: build ``cfg.system``, resolve hosts/ports
        and wire the coordinator to every learner and collector.
    Arguments:
        - cfg (:obj:`dict`): Total config with ``main`` and ``system`` sections.
        - coordinator_host (:obj:`Optional[str]`): Coordinator host, ``default_host`` if ``None``.
        - learner_host (:obj:`Optional[Union[str, List[str]]]`): Learner host(s), ``default_host`` if ``None``.
        - collector_host (:obj:`Optional[Union[str, List[str]]]`): Collector host(s), ``default_host`` if \
            ``None``.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    coordinator_host = default_host if coordinator_host is None else coordinator_host
    collector_host = default_host if collector_host is None else collector_host
    learner_host = default_host if learner_host is None else learner_host
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port(cfg.system, coordinator_host, learner_host, collector_host)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    return cfg


def parallel_transform_slurm(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_node: Optional[Union[str, List[str]]] = None,
        collector_node: Optional[Union[str, List[str]]] = None
) -> EasyDict:
    """
    Overview:
        Slurm counterpart of ``parallel_transform``; also prints the resulting config via ``pretty_print``.
    Arguments:
        - cfg (:obj:`dict`): Total config with ``main`` and ``system`` sections.
        - coordinator_host (:obj:`Optional[str]`): Coordinator host (passed through as-is, even if ``None``).
        - learner_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for learners.
        - collector_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_slurm(cfg.system, coordinator_host, learner_node, collector_node)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    pretty_print(cfg)
    return cfg


def parallel_transform_k8s(
        cfg: dict,
        coordinator_port: Optional[int] = None,
        learner_port: Optional[int] = None,
        collector_port: Optional[int] = None
) -> EasyDict:
    """
    Overview:
        k8s counterpart of ``parallel_transform``; also prints the resulting config via ``pretty_print``.
    Arguments:
        - cfg (:obj:`dict`): Total config with ``main`` and ``system`` sections.
        - coordinator_port (:obj:`Optional[int]`): Coordinator port, k8s default if ``None``.
        - learner_port (:obj:`Optional[int]`): Learner port, k8s default if ``None``.
        - collector_port (:obj:`Optional[int]`): Collector port, k8s default if ``None``.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_k8s(cfg.system, coordinator_port, learner_port, collector_port)
    # learners/collectors are created by the k8s operator, so the following fields are only placeholders
    cfg.system.coordinator.collector = {}
    cfg.system.coordinator.learner = {}
    pretty_print(cfg)
    return cfg


def _format_value(value) -> str:
    """
    Overview:
        Render one config value as python source text for ``save_config_formatted``.
    Arguments:
        - value (:obj:`Any`): The value to render.
    Returns:
        - text (:obj:`str`): ``repr`` for strings, ``float('inf')`` / ``float('-inf')`` / ``float('nan')`` \
            for non-finite reals, ``'{}'.format(value)`` otherwise.
    """
    if isinstance(value, str):
        # repr keeps embedded quotes and backslashes (e.g. Windows paths) valid in the generated file;
        # for ordinary strings it is exactly "'<value>'".
        return repr(value)
    if isinstance(value, numbers.Real) and not isinstance(value, numbers.Integral) and not math.isfinite(value):
        # the bare words inf / nan are not valid python names
        return "float('{}')".format(value)
    return '{}'.format(value)


def _write_entry(f: TextIO, key, value, depth: int = 0) -> None:
    """
    Overview:
        Write one ``key=value,`` line of the generated config. While ``depth`` > 0, dict values are expanded
        into a nested ``key=dict(...)`` block whose entries use ``depth - 1``; otherwise the value is written
        inline via ``_format_value``.
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - key (:obj:`str`): Entry name (must be a valid python identifier when written inside ``dict(...)``).
        - value (:obj:`Any`): Entry value.
        - depth (:obj:`int`): How many levels of dict values may still be expanded.
    """
    if depth > 0 and isinstance(value, dict):
        f.write(" {}=dict(\n".format(key))
        for sub_key, sub_value in value.items():
            _write_entry(f, sub_key, sub_value, depth - 1)
        f.write(" ),\n")
    else:
        f.write(" {}={},\n".format(key, _format_value(value)))


def _write_stage(f: TextIO, name: str, stage_cfg: dict, worker_key: str, expand_keys: tuple = ()) -> None:
    """
    Overview:
        Write ``name=dict(...)`` for a policy stage (learn / collect / eval). Entries are written inline,
        except ``worker_key``, which becomes a nested ``dict(...)`` in which the ``expand_keys`` entries are
        expanded one more level.
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - name (:obj:`str`): Stage name, e.g. ``'learn'``.
        - stage_cfg (:obj:`dict`): Stage config.
        - worker_key (:obj:`str`): Key of the worker sub-config, e.g. ``'learner'``.
        - expand_keys (:obj:`tuple`): Worker entries whose dict value is written as ``dict(...)``.
    """
    f.write(" {}=dict(\n".format(name))
    for key, value in stage_cfg.items():
        if key == worker_key:
            f.write(" {}=dict(\n".format(key))
            for sub_key, sub_value in value.items():
                _write_entry(f, sub_key, sub_value, 1 if sub_key in expand_keys else 0)
            f.write(" ),\n")
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_other(f: TextIO, other_cfg: dict) -> None:
    """
    Overview:
        Write ``policy.other``. In ``replay_buffer``, ``monitor`` / ``thruput_controller`` are expanded two
        dict levels deep and every other entry three levels deep. Other entries of ``policy.other`` are
        written too, so the dumped config does not lose them.
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - other_cfg (:obj:`dict`): The ``policy.other`` config.
    """
    f.write(" other=dict(\n")
    for key, value in other_cfg.items():
        if key == 'replay_buffer':
            f.write(" replay_buffer=dict(\n")
            for buffer_key, buffer_value in value.items():
                depth = 2 if buffer_key in ('monitor', 'thruput_controller') else 3
                _write_entry(f, buffer_key, buffer_value, depth)
            f.write(" ),\n")
        else:
            _write_entry(f, key, value, 3)
    f.write(" ),\n")


def _write_main_env(f: TextIO, env_cfg: dict) -> None:
    """
    Overview:
        Write ``main_config.env``: everything except ``type`` / ``import_names`` and the manager's
        ``type`` / ``cfg_type``, which belong to ``create_config``.
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - env_cfg (:obj:`dict`): The ``env`` config.
    """
    f.write(' env=dict(\n')
    for key, value in env_cfg.items():
        if key == 'manager':
            f.write(" manager=dict(\n")
            for sub_key, sub_value in value.items():
                # filter on the KEY; type / cfg_type go to create_config.env_manager
                if sub_key not in ('cfg_type', 'type'):
                    _write_entry(f, sub_key, sub_value)
            f.write(" ),\n")
        elif key not in ('type', 'import_names'):
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_main_policy(f: TextIO, policy_cfg: dict) -> None:
    """
    Overview:
        Write ``main_config.policy`` (everything except ``type``, which goes to ``create_config``).
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - policy_cfg (:obj:`dict`): The ``policy`` config.
    """
    f.write(' policy=dict(\n')
    for key, value in policy_cfg.items():
        if key == 'type':
            continue
        if key == 'learn':
            _write_stage(f, 'learn', value, 'learner', expand_keys=('dataloader', 'hook'))
        elif key == 'collect':
            _write_stage(f, 'collect', value, 'collector')
        elif key == 'eval':
            _write_stage(f, 'eval', value, 'evaluator')
        elif key == 'model':
            _write_entry(f, 'model', value, 1)
        elif key == 'other':
            _write_other(f, value)
        else:
            _write_entry(f, key, value)
    f.write(" ),\n")


def _write_create_env(f: TextIO, env_cfg: dict) -> None:
    """
    Overview:
        Write ``create_config.env`` (``type`` / ``import_names``) and, if present, ``env_manager``
        (the manager's ``type`` / ``cfg_type``).
    Arguments:
        - f (:obj:`TextIO`): Output file.
        - env_cfg (:obj:`dict`): The ``env`` config.
    """
    f.write(' env=dict(\n')
    for key, value in env_cfg.items():
        if key in ('type', 'import_names'):
            _write_entry(f, key, value)
    f.write(" ),\n")
    if 'manager' in env_cfg:
        f.write(' env_manager=dict(\n')
        for key, value in env_cfg['manager'].items():
            if key in ('cfg_type', 'type'):
                _write_entry(f, key, value)
        f.write(" ),\n")


def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py') -> None:
    """
    Overview:
        Save a formatted configuration to a python file that can be read by serial_pipeline directly. The
        file defines ``main_config`` (everything except type/import information) and ``create_config``
        (env / env_manager / policy types).
    Arguments:
        - config_ (:obj:`dict`): Config dict (attribute access is used, e.g. an ``EasyDict``).
        - path (:obj:`str`): Path of the python file to write.
    """
    with open(path, "w", encoding='utf-8') as f:
        f.write('from easydict import EasyDict\n\n')
        f.write('main_config = dict(\n')
        f.write(" exp_name={},\n".format(_format_value(str(config_.exp_name))))
        for k, v in config_.items():
            if k == 'env':
                _write_main_env(f, v)
            elif k == 'policy':
                _write_main_policy(f, v)
        # Close main_config after the loop so the file stays valid whatever the key order of config_ is.
        f.write(")\n")
        f.write('main_config = EasyDict(main_config)\n')
        f.write('main_config = main_config\n')
        f.write('create_config = dict(\n')
        for k, v in config_.items():
            if k == 'env':
                _write_create_env(f, v)
        policy_type = config_.policy.type
        # Strip a trailing '_command' (e.g. 'dqn_command' -> 'dqn'); endswith avoids mangling names that
        # merely contain it somewhere else.
        suffix = '_command'
        if policy_type.endswith(suffix):
            policy_type = policy_type[:-len(suffix)]
        f.write(" policy=dict(type={}),\n".format(_format_value(policy_type)))
        f.write(")\n")
        f.write('create_config = EasyDict(create_config)\n')
        f.write('create_config = create_config\n')


# Ready-made cartpole DQN configs for exercising the parallel pipeline.
parallel_test_main_config = cartpole_dqn_config
parallel_test_create_config = dict(
    env=dict(
        type='cartpole',
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn_command'),
    comm_learner=dict(
        type='flask_fs',
        import_names=['ding.worker.learner.comm.flask_fs_learner'],
    ),
    comm_collector=dict(
        type='flask_fs',
        import_names=['ding.worker.collector.comm.flask_fs_collector'],
    ),
    learner=dict(
        type='base',
        import_names=['ding.worker.learner.base_learner'],
    ),
    collector=dict(
        type='zergling',
        import_names=['ding.worker.collector.zergling_parallel_collector'],
    ),
    commander=dict(
        type='naive',
        import_names=['ding.worker.coordinator.base_parallel_commander'],
    ),
)
parallel_test_create_config = EasyDict(parallel_test_create_config)
parallel_test_system_config = dict(
    coordinator=dict(),
    path_data='.',
    path_policy='.',
    communication_mode='auto',
    learner_gpu_num=1,
)
parallel_test_system_config = EasyDict(parallel_test_system_config)