zjowowen's picture
init space
079c32c
raw
history blame
No virus
28.1 kB
import copy
import keyword
import math
from typing import Optional, List

from easydict import EasyDict

from ding.utils import find_free_port, find_free_port_slurm, node_to_partition, node_to_host, pretty_print, \
    DEFAULT_K8S_COLLECTOR_PORT, DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_COORDINATOR_PORT
from dizoo.classic_control.cartpole.config.parallel import cartpole_dqn_config
# Wildcard bind address used whenever no explicit host is supplied, and a default port number
# (the port is not referenced by the helpers visible in this module).
default_host, default_port = '0.0.0.0', 22270
def set_host_port(cfg: EasyDict, coordinator_host: str, learner_host: str, collector_host: str) -> EasyDict:
    """
    Overview:
        Resolve ``'auto'`` hosts and ports of the coordinator, learner and collector entries of a system config.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with a ``coordinator`` entry and ``learner*`` / ``collector*``
            entries, e.g. as produced by ``set_system_cfg``.
        - coordinator_host (:obj:`str`): Host assigned to the coordinator.
        - learner_host (:obj:`Union[str, List[str]]`): Either one host shared by every learner whose host is
            ``'auto'``, or a list of hosts consumed in order by those learners.
        - collector_host (:obj:`Union[str, List[str]]`): Same as ``learner_host``, for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, updated in place.
    Raises:
        - NotImplementedError: If ``cfg`` contains a ``learner_aggregator*`` entry.
        - TypeError: If a host argument is neither a ``str`` nor a ``list``.
        - IndexError: If a host list has fewer elements than the entries whose host is ``'auto'``.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    learner_count = 0
    collector_count = 0
    for k in cfg.keys():
        # Aggregator entries are named 'learner_aggregator<suffix>' (see set_host_port_slurm). They also start
        # with 'learner', so an exact-name comparison would let them fall through into the learner branch.
        if k.startswith('learner_aggregator'):
            raise NotImplementedError
        if k.startswith('learner'):
            if cfg[k].host == 'auto':
                if isinstance(learner_host, list):
                    # one host per learner, consumed in iteration order
                    cfg[k].host = learner_host[learner_count]
                    learner_count += 1
                elif isinstance(learner_host, str):
                    cfg[k].host = learner_host
                else:
                    raise TypeError("not support learner_host type: {}".format(type(learner_host)))
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port(cfg[k].host)
            cfg[k].aggregator = False
        if k.startswith('collector'):
            if cfg[k].host == 'auto':
                if isinstance(collector_host, list):
                    # one host per collector, consumed in iteration order
                    cfg[k].host = collector_host[collector_count]
                    collector_count += 1
                elif isinstance(collector_host, str):
                    cfg[k].host = collector_host
                else:
                    raise TypeError("not support collector_host type: {}".format(type(collector_host)))
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port(cfg[k].host)
    return cfg
def set_host_port_slurm(cfg: EasyDict, coordinator_host: str, learner_node: list, collector_node: list) -> EasyDict:
    """
    Overview:
        Assign slurm nodes, partitions, hosts and ports to the learner and collector entries of a system config,
        and add a ``learner_aggregator<suffix>`` entry for every learner that uses more than one GPU.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with a ``coordinator`` entry and ``learner*`` / ``collector*``
            entries (learner entries need a ``gpu_num`` field), e.g. as produced by ``set_system_cfg``.
        - coordinator_host (:obj:`str`): Host assigned to the coordinator.
        - learner_node (:obj:`Union[str, List[str], None]`): Node(s) for learners, assigned round-robin.
            ``None`` leaves learner entries untouched.
        - collector_node (:obj:`Union[str, List[str], None]`): Same as ``learner_node``, for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, updated in place.
    .. note::
        When a learner has ``gpu_num > 1`` and its port is given explicitly, the port must be a list with one
        port per GPU, because each port becomes one learner of the aggregator.
    """
    cfg.coordinator.host = coordinator_host
    if cfg.coordinator.port == 'auto':
        cfg.coordinator.port = find_free_port(coordinator_host)
    if isinstance(learner_node, str):
        learner_node = [learner_node]
    if isinstance(collector_node, str):
        collector_node = [collector_node]
    learner_count, collector_count = 0, 0
    learner_multi = {}
    for k in cfg.keys():
        # Aggregator entries also start with 'learner' but are not learners (they carry no ``gpu_num``).
        is_learner = k.startswith('learner') and not k.startswith('learner_aggregator')
        if learner_node is not None and is_learner:
            node = learner_node[learner_count % len(learner_node)]  # round-robin over the given nodes
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            gpu_num = cfg[k].gpu_num
            # Same criterion as ``multi_gpu`` in set_system_cfg, so a learner with ``gpu_num == 0`` gets a
            # single port instead of an empty port list.
            multi_gpu = gpu_num > 1
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                if multi_gpu:
                    cfg[k].port = [find_free_port_slurm(node) for _ in range(gpu_num)]
                else:
                    cfg[k].port = find_free_port_slurm(node)
            # Recorded even when the port was given explicitly, so ``aggregator`` is always set below.
            learner_multi[k] = multi_gpu
            learner_count += 1
        if collector_node is not None and k.startswith('collector'):
            node = collector_node[collector_count % len(collector_node)]  # round-robin over the given nodes
            cfg[k].node = node
            cfg[k].partition = node_to_partition(node)
            if cfg[k].host == 'auto':
                cfg[k].host = node_to_host(node)
            if cfg[k].port == 'auto':
                cfg[k].port = find_free_port_slurm(node)
            collector_count += 1
    for k, multi_gpu in learner_multi.items():
        if not multi_gpu:
            cfg[k].aggregator = False
            continue
        # A multi-GPU learner is reached through an aggregator on the same node: each per-GPU port becomes one
        # learner entry of the aggregator (set_learner_interaction_for_coordinator uses its ``slave`` address).
        host = cfg[k].host
        learner_interaction_cfg = {str(i): [str(i), host, p] for i, p in enumerate(cfg[k].port)}
        aggregator_cfg = dict(
            master=dict(
                host=host,
                port=find_free_port_slurm(cfg[k].node),
            ),
            slave=dict(
                host=host,
                port=find_free_port_slurm(cfg[k].node),
            ),
            learner=learner_interaction_cfg,
            node=cfg[k].node,
            partition=cfg[k].partition,
        )
        cfg[k].aggregator = True
        cfg['learner_aggregator' + k[len('learner'):]] = aggregator_cfg
    return cfg
def set_host_port_k8s(cfg: EasyDict, coordinator_port: int, learner_port: int, collector_port: int) -> EasyDict:
    """
    Overview:
        Assign Kubernetes ports to a system config. Every ``learner*`` / ``collector*`` entry gets the learner /
        collector port, and ``cfg.learner`` / ``cfg.collector`` are set to a deep copy of the first matching
        entry with ``default_host`` as host and the role's port.
    Arguments:
        - cfg (:obj:`EasyDict`): System config with a ``coordinator`` entry and ``learner*`` / ``collector*``
            entries.
        - coordinator_port (:obj:`Optional[int]`): Coordinator port; ``DEFAULT_K8S_COORDINATOR_PORT`` if ``None``.
        - learner_port (:obj:`Optional[int]`): Learner port; ``DEFAULT_K8S_LEARNER_PORT`` if ``None``.
        - collector_port (:obj:`Optional[int]`): Collector port; ``DEFAULT_K8S_COLLECTOR_PORT`` if ``None``.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, updated in place. ``cfg.learner`` / ``cfg.collector``
            are ``None`` when no matching entry exists.
    """
    cfg.coordinator.host = default_host
    if coordinator_port is None:
        coordinator_port = DEFAULT_K8S_COORDINATOR_PORT
    cfg.coordinator.port = coordinator_port

    port_of_role = {
        'learner': DEFAULT_K8S_LEARNER_PORT if learner_port is None else learner_port,
        'collector': DEFAULT_K8S_COLLECTOR_PORT if collector_port is None else collector_port,
    }
    base_of_role = {'learner': None, 'collector': None}
    for name in cfg.keys():
        if name.startswith('learner'):
            role = 'learner'
        elif name.startswith('collector'):
            role = 'collector'
        else:
            continue
        role_port = port_of_role[role]
        if base_of_role[role] is None:
            # the first entry of each role is snapshotted as that role's base config
            base = copy.deepcopy(cfg[name])
            base.host = default_host
            base.port = role_port
            base_of_role[role] = base
        cfg[name].port = role_port
    cfg['learner'] = base_of_role['learner']
    cfg['collector'] = base_of_role['collector']
    return cfg
def set_learner_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build ``cfg.coordinator.learner``, mapping every learner name to ``[name, host, port]``. A learner whose
        ``aggregator`` flag is set is reached through the ``slave`` address of its ``learner_aggregator<suffix>``
        entry; any other learner uses its own host and port.
    Arguments:
        - cfg (:obj:`EasyDict`): System config whose learner entries carry ``host``, ``port`` and ``aggregator``.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, updated in place.
    """
    cfg.coordinator.learner = {}
    prefix = 'learner'
    for name in cfg.keys():
        is_plain_learner = name.startswith(prefix) and not name.startswith('learner_aggregator')
        if not is_plain_learner:
            continue
        if cfg[name].aggregator:
            endpoint = cfg['learner_aggregator' + name[len(prefix):]].slave
        else:
            endpoint = cfg[name]
        cfg.coordinator.learner[name] = [name, endpoint.host, endpoint.port]
    return cfg
def set_collector_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Build ``cfg.coordinator.collector``, mapping every ``collector*`` entry name to ``[name, host, port]``.
    Arguments:
        - cfg (:obj:`EasyDict`): System config whose collector entries carry ``host`` and ``port``.
    Returns:
        - cfg (:obj:`EasyDict`): The same ``cfg`` object, updated in place.
    """
    cfg.coordinator.collector = {}
    collector_names = [name for name in cfg.keys() if name.startswith('collector')]
    for name in collector_names:
        entry = cfg[name]
        cfg.coordinator.collector[name] = [name, entry.host, entry.port]
    return cfg
def set_system_cfg(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Expand the ``system`` section of a config into a per-process system config: a ``coordinator`` entry
        (``host``/``port`` default to ``'auto'`` and are overridden by ``cfg.system.coordinator``) plus
        ``learner{i}`` and ``collector{i}`` entries whose ``host`` and ``port`` are ``'auto'``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config providing ``main.policy.learn.learner.learner_num``,
            ``main.policy.collect.collector.collector_num`` and a ``system`` section.
    Returns:
        - system_cfg (:obj:`EasyDict`): A newly built system config.
    """
    learner_cnt = cfg.main.policy.learn.learner.learner_num
    collector_cnt = cfg.main.policy.collect.collector.collector_num
    system = cfg.system
    path_data, path_policy = system.path_data, system.path_policy
    coordinator_overrides = system.coordinator
    mode = system.communication_mode
    # only the 'auto' communication mode is handled here
    assert mode in ['auto'], mode
    gpu_num = system.learner_gpu_num
    use_multi_gpu = gpu_num > 1

    coordinator = {'host': 'auto', 'port': 'auto'}
    coordinator.update(coordinator_overrides)
    system_cfg = {'coordinator': coordinator}
    for idx in range(learner_cnt):
        system_cfg['learner{}'.format(idx)] = {
            'type': system.comm_learner.type,
            'import_names': system.comm_learner.import_names,
            'host': 'auto',
            'port': 'auto',
            'path_data': path_data,
            'path_policy': path_policy,
            'multi_gpu': use_multi_gpu,
            'gpu_num': gpu_num,
        }
    for idx in range(collector_cnt):
        system_cfg['collector{}'.format(idx)] = {
            'type': system.comm_collector.type,
            'import_names': system.comm_collector.import_names,
            'host': 'auto',
            'port': 'auto',
            'path_data': path_data,
            'path_policy': path_policy,
        }
    return EasyDict(system_cfg)
def parallel_transform(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_host: Optional[List[str]] = None,
        collector_host: Optional[List[str]] = None
) -> EasyDict:
    """
    Overview:
        Turn a config into its parallel-mode form: build ``cfg.system`` with ``set_system_cfg``, resolve hosts
        and ports with ``set_host_port`` and fill the coordinator's learner / collector interaction tables.
    Arguments:
        - cfg (:obj:`dict`): Config with ``main`` and ``system`` sections.
        - coordinator_host (:obj:`Optional[str]`): Coordinator host; ``default_host`` if ``None``.
        - learner_host (:obj:`Optional[Union[str, List[str]]]`): Learner host(s); ``default_host`` if ``None``.
        - collector_host (:obj:`Optional[Union[str, List[str]]]`): Collector host(s); ``default_host`` if ``None``.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    coordinator_host = default_host if coordinator_host is None else coordinator_host
    collector_host = default_host if collector_host is None else collector_host
    learner_host = default_host if learner_host is None else learner_host
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port(cfg.system, coordinator_host, learner_host, collector_host)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    return cfg
def parallel_transform_slurm(
        cfg: dict,
        coordinator_host: Optional[str] = None,
        learner_node: Optional[List[str]] = None,
        collector_node: Optional[List[str]] = None
) -> EasyDict:
    """
    Overview:
        Turn a config into its parallel-mode form for slurm: build ``cfg.system`` with ``set_system_cfg``,
        assign nodes/hosts/ports with ``set_host_port_slurm``, fill the coordinator's interaction tables and
        print the result.
    Arguments:
        - cfg (:obj:`dict`): Config with ``main`` and ``system`` sections.
        - coordinator_host (:obj:`Optional[str]`): Coordinator host, passed through unchanged (no default is
            substituted for ``None``).
        - learner_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for learners.
        - collector_node (:obj:`Optional[Union[str, List[str]]]`): Slurm node(s) for collectors.
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_slurm(cfg.system, coordinator_host, learner_node, collector_node)
    cfg.system = set_learner_interaction_for_coordinator(cfg.system)
    cfg.system = set_collector_interaction_for_coordinator(cfg.system)
    pretty_print(cfg)
    return cfg
def parallel_transform_k8s(
        cfg: dict,
        coordinator_port: Optional[int] = None,
        learner_port: Optional[int] = None,
        collector_port: Optional[int] = None
) -> EasyDict:
    """
    Overview:
        Turn a config into its parallel-mode form for Kubernetes: build ``cfg.system`` with ``set_system_cfg``,
        assign ports with ``set_host_port_k8s``, leave the coordinator's interaction tables empty and print the
        result.
    Arguments:
        - cfg (:obj:`dict`): Config with ``main`` and ``system`` sections.
        - coordinator_port (:obj:`Optional[int]`): Coordinator port (k8s default if ``None``).
        - learner_port (:obj:`Optional[int]`): Learner port (k8s default if ``None``).
        - collector_port (:obj:`Optional[int]`): Collector port (k8s default if ``None``).
    Returns:
        - cfg (:obj:`EasyDict`): The transformed config.
    """
    cfg = EasyDict(cfg)
    cfg.system = set_system_cfg(cfg)
    cfg.system = set_host_port_k8s(cfg.system, coordinator_port, learner_port, collector_port)
    # learner/collector is created by operator, so the following field is placeholder
    cfg.system.coordinator.collector = {}
    cfg.system.coordinator.learner = {}
    pretty_print(cfg)
    return cfg
def _format_config_value(value) -> str:
    """
    Overview:
        Render ``value`` as a python expression for the generated config file. Strings use ``repr`` (so quotes
        inside them stay valid), non-finite floats become ``float('inf')`` / ``float('-inf')`` / ``float('nan')``,
        dicts, lists and tuples are rendered recursively, and anything else falls back to ``str(value)``.
    """
    if isinstance(value, str):
        return repr(value)
    if isinstance(value, float) and not math.isfinite(value):
        return "float('{}')".format(value)
    if isinstance(value, dict):
        items = ('{}: {}'.format(_format_config_value(k), _format_config_value(v)) for k, v in value.items())
        return '{' + ', '.join(items) + '}'
    if isinstance(value, list):
        return '[' + ', '.join(_format_config_value(v) for v in value) + ']'
    if isinstance(value, tuple):
        items = [_format_config_value(v) for v in value]
        # a one-element tuple needs a trailing comma to stay a tuple
        return '(' + ', '.join(items) + (',' if len(items) == 1 else '') + ')'
    return str(value)


def _is_keyword_dict(value) -> bool:
    """Return True if ``value`` is a dict whose keys can all be written as ``dict(key=...)`` keyword arguments."""
    return isinstance(value, dict) and all(
        isinstance(k, str) and k.isidentifier() and not keyword.iskeyword(k) for k in value
    )


def _write_config_entry(f, key: str, value, depth: int) -> None:
    """Write ``key=value,`` indented by ``depth`` levels; keyword-compatible dicts become nested ``dict(...)``."""
    indent = '    ' * depth
    if _is_keyword_dict(value):
        f.write('{}{}=dict(\n'.format(indent, key))
        for sub_key, sub_value in value.items():
            _write_config_entry(f, sub_key, sub_value, depth + 1)
        f.write('{}),\n'.format(indent))
    else:
        f.write('{}{}={},\n'.format(indent, key, _format_config_value(value)))


def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py') -> None:
    """
    Overview:
        Save a formatted configuration to a python file that can be read by serial_pipeline directly.
        The file defines ``main_config`` (``exp_name`` plus the ``env`` and ``policy`` settings) and
        ``create_config`` (the env ``type``/``import_names``, the env manager ``type``/``cfg_type`` and the
        policy type with a trailing ``_command`` removed). From ``policy.other`` only ``replay_buffer`` is
        exported.
    Arguments:
        - config_ (:obj:`dict`): Config dict with ``exp_name``, ``env`` and ``policy`` entries (an ``EasyDict``
            works as well, since only item access is used).
        - path (:obj:`str`): Path of python file
    """
    with open(path, "w", encoding='utf-8') as f:
        f.write('from easydict import EasyDict\n\n')
        f.write('main_config = dict(\n')
        # exp_name is always written as a string literal
        f.write('    exp_name={},\n'.format(_format_config_value(str(config_['exp_name']))))
        for k, v in config_.items():
            if k == 'env':
                f.write('    env=dict(\n')
                for env_key, env_value in v.items():
                    if env_key in ('type', 'import_names'):
                        continue  # written to create_config.env instead
                    if env_key == 'manager' and isinstance(env_value, dict):
                        # the manager's type fields are written to create_config.env_manager instead
                        env_value = {mk: mv for mk, mv in env_value.items() if mk not in ('cfg_type', 'type')}
                    _write_config_entry(f, env_key, env_value, 2)
                f.write('    ),\n')
            elif k == 'policy':
                f.write('    policy=dict(\n')
                for policy_key, policy_value in v.items():
                    if policy_key == 'type':
                        continue  # written to create_config.policy instead
                    if policy_key == 'other' and isinstance(policy_value, dict):
                        # only the replay buffer settings of ``policy.other`` are exported
                        policy_value = {ok: ov for ok, ov in policy_value.items() if ok == 'replay_buffer'}
                    _write_config_entry(f, policy_key, policy_value, 2)
                f.write('    ),\n')
        # close main_config after the loop so the output is valid whatever the key order of ``config_``
        f.write(')\n')
        f.write('main_config = EasyDict(main_config)\n')
        f.write('main_config = main_config\n')

        f.write('create_config = dict(\n')
        if 'env' in config_:
            env_cfg = config_['env']
            f.write('    env=dict(\n')
            for env_key, env_value in env_cfg.items():
                if env_key in ('type', 'import_names'):
                    _write_config_entry(f, env_key, env_value, 2)
            f.write('    ),\n')
            if 'manager' in env_cfg:
                f.write('    env_manager=dict(\n')
                for manager_key, manager_value in env_cfg['manager'].items():
                    if manager_key in ('cfg_type', 'type'):
                        _write_config_entry(f, manager_key, manager_value, 2)
                f.write('    ),\n')
        policy_type = config_['policy']['type']
        # strip a trailing '_command' so create_config names the plain policy type (e.g. 'dqn_command' -> 'dqn')
        command_suffix = '_command'
        if policy_type.endswith(command_suffix):
            policy_type = policy_type[:-len(command_suffix)]
        f.write('    policy=dict(type={}),\n'.format(_format_config_value(policy_type)))
        f.write(')\n')
        f.write('create_config = EasyDict(create_config)\n')
        f.write('create_config = create_config\n')
# Fixtures for a small parallel-mode run built on the cartpole DQN config: the main config, the
# create config (component types and their import paths) and a minimal system config.
parallel_test_main_config = cartpole_dqn_config
parallel_test_create_config = EasyDict(
    dict(
        env=dict(
            type='cartpole',
            import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
        ),
        env_manager=dict(type='subprocess'),
        policy=dict(type='dqn_command'),
        comm_learner=dict(
            type='flask_fs',
            import_names=['ding.worker.learner.comm.flask_fs_learner'],
        ),
        comm_collector=dict(
            type='flask_fs',
            import_names=['ding.worker.collector.comm.flask_fs_collector'],
        ),
        learner=dict(
            type='base',
            import_names=['ding.worker.learner.base_learner'],
        ),
        collector=dict(
            type='zergling',
            import_names=['ding.worker.collector.zergling_parallel_collector'],
        ),
        commander=dict(
            type='naive',
            import_names=['ding.worker.coordinator.base_parallel_commander'],
        ),
    )
)
parallel_test_system_config = EasyDict(
    dict(
        coordinator=dict(),
        path_data='.',
        path_policy='.',
        communication_mode='auto',
        learner_gpu_num=1,
    )
)