|
import pytest |
|
import time |
|
import os |
|
import torch |
|
import subprocess |
|
from copy import deepcopy |
|
|
|
from ding.entry import serial_pipeline, serial_pipeline_offline, collect_demo_data, serial_pipeline_onpolicy |
|
from ding.entry.serial_entry_sqil import serial_pipeline_sqil |
|
from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config |
|
from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main |
|
from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main |
|
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config |
|
from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config |
|
from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config |
|
from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config |
|
from dizoo.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config |
|
from dizoo.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main |
|
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config |
|
from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main |
|
from dizoo.league_demo.league_demo_ppo_main import main as league_main |
|
from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import pendulum_sac_data_genearation_config, pendulum_sac_data_genearation_create_config |
|
from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config |
|
from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config |
|
from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config |
|
from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config |
|
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config |
|
from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config |
|
|
|
# Start a fresh record file when this module is imported: mode "w+" truncates
# whatever a previous run left behind. The tests in this file append one line
# each to the same path when they succeed.
with open("./algo_record.log", "w+") as f:
    print("ALGO TEST STARTS", file=f)
|
|
|
|
|
@pytest.mark.algotest
def test_dqn():
    """Run ``serial_pipeline`` on copies of the CartPole DQN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "1. dqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_dqn_config)
    create_cfg = deepcopy(cartpole_dqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("1. dqn\n")
|
|
|
|
|
@pytest.mark.algotest
def test_ddpg():
    """Run ``serial_pipeline`` on copies of the Pendulum DDPG config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "2. ddpg" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_ddpg_config)
    create_cfg = deepcopy(pendulum_ddpg_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("2. ddpg\n")
|
|
|
|
|
@pytest.mark.algotest
def test_td3():
    """Run ``serial_pipeline`` on copies of the Pendulum TD3 config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "3. td3" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_td3_config)
    create_cfg = deepcopy(pendulum_td3_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("3. td3\n")
|
|
|
|
|
@pytest.mark.algotest
def test_a2c():
    """Run ``serial_pipeline_onpolicy`` on copies of the CartPole A2C config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "4. a2c" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_a2c_config)
    create_cfg = deepcopy(cartpole_a2c_create_config)
    try:
        serial_pipeline_onpolicy([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("4. a2c\n")
|
|
|
|
|
@pytest.mark.algotest
def test_rainbow():
    """Run ``serial_pipeline`` on copies of the CartPole Rainbow config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "5. rainbow" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_rainbow_config)
    create_cfg = deepcopy(cartpole_rainbow_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("5. rainbow\n")
|
|
|
|
|
@pytest.mark.algotest
def test_ppo():
    """Run the CartPole PPO entry point ``ppo_main`` on a copy of its config.

    ``ppo_main`` receives only the main config, so no create-config copy is
    built here. Any ``Exception`` fails the test with "pipeline fail". On
    success the line "6. ppo" is appended to ``./algo_record.log``.
    """
    # Copy so the entry point cannot alter the imported module-level config.
    main_cfg = deepcopy(cartpole_ppo_config)
    try:
        ppo_main(main_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the entry point did not raise.
    with open("./algo_record.log", "a+") as f:
        f.write("6. ppo\n")
|
|
|
|
|
|
|
def test_collaq():
    """Run ``serial_pipeline`` on copies of the simple-spread CollaQ config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "7. collaq" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_collaq_config)
    create_cfg = deepcopy(ptz_simple_spread_collaq_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("7. collaq\n")
|
|
|
|
|
|
|
def test_coma():
    """Run ``serial_pipeline`` on copies of the simple-spread COMA config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "8. coma" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_coma_config)
    create_cfg = deepcopy(ptz_simple_spread_coma_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("8. coma\n")
|
|
|
|
|
@pytest.mark.algotest
def test_sac():
    """Run ``serial_pipeline`` on copies of the Pendulum SAC config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "9. sac" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_sac_config)
    create_cfg = deepcopy(pendulum_sac_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("9. sac\n")
|
|
|
|
|
@pytest.mark.algotest
def test_c51():
    """Run ``serial_pipeline`` on copies of the CartPole C51 config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "10. c51" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_c51_config)
    create_cfg = deepcopy(cartpole_c51_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("10. c51\n")
|
|
|
|
|
@pytest.mark.algotest
def test_r2d2():
    """Run ``serial_pipeline`` on copies of the CartPole R2D2 config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "11. r2d2" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_r2d2_config)
    create_cfg = deepcopy(cartpole_r2d2_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("11. r2d2\n")
|
|
|
|
|
@pytest.mark.algotest
def test_pg():
    """Run ``serial_pipeline_onpolicy`` on copies of the CartPole PG config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "12. pg" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_pg_config)
    create_cfg = deepcopy(cartpole_pg_create_config)
    try:
        serial_pipeline_onpolicy([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("12. pg\n")
|
|
|
|
|
|
|
def test_atoc():
    """Run ``serial_pipeline`` on copies of the simple-spread ATOC config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "13. atoc" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_atoc_config)
    create_cfg = deepcopy(ptz_simple_spread_atoc_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("13. atoc\n")
|
|
|
|
|
|
|
def test_vdn():
    """Run ``serial_pipeline`` on copies of the simple-spread VDN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "14. vdn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_vdn_config)
    create_cfg = deepcopy(ptz_simple_spread_vdn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("14. vdn\n")
|
|
|
|
|
|
|
def test_qmix():
    """Run ``serial_pipeline`` on copies of the simple-spread QMIX config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "15. qmix" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_qmix_config)
    create_cfg = deepcopy(ptz_simple_spread_qmix_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("15. qmix\n")
|
|
|
|
|
@pytest.mark.algotest
def test_impala():
    """Run ``serial_pipeline`` on copies of the CartPole IMPALA config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "16. impala" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_impala_config)
    create_cfg = deepcopy(cartpole_impala_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("16. impala\n")
|
|
|
|
|
@pytest.mark.algotest
def test_iqn():
    """Run ``serial_pipeline`` on copies of the CartPole IQN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "17. iqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_iqn_config)
    create_cfg = deepcopy(cartpole_iqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("17. iqn\n")
|
|
|
|
|
@pytest.mark.algotest
def test_her_dqn():
    """Run the bitflip HER-DQN entry point with a 5-bit override of its config.

    The overrides (``exp_name``, ``env.n_bits``, ``policy.model.obs_shape`` and
    ``policy.model.action_shape``) are applied to a deep copy, so the imported
    module-level ``bitflip_her_dqn_config`` is left untouched. Any
    ``Exception`` raised while preparing the config or running the entry point
    fails the test with "pipeline fail". On success the line "18. her dqn" is
    appended to ``./algo_record.log``.
    """
    try:
        # Work on a private copy, as the other tests in this file do, so the
        # edits below cannot leak into later users of the shared config.
        cfg = deepcopy(bitflip_her_dqn_config)
        cfg.exp_name = 'bitflip5_dqn'
        cfg.env.n_bits = 5
        cfg.policy.model.obs_shape = 10
        cfg.policy.model.action_shape = 5
        bitflip_dqn_main(cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the entry point did not raise.
    with open("./algo_record.log", "a+") as f:
        f.write("18. her dqn\n")
|
|
|
|
|
@pytest.mark.algotest
def test_ppg():
    """Run the CartPole PPG entry point ``ppg_main`` on a copy of its config.

    The config is deep-copied first, matching the other entry-point tests in
    this file, so the imported ``cartpole_ppg_config`` object is never handed
    to the entry point directly. Any ``Exception`` fails the test with
    "pipeline fail". On success the line "19. ppg" is appended to
    ``./algo_record.log``.
    """
    try:
        ppg_main(deepcopy(cartpole_ppg_config), seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the entry point did not raise.
    with open("./algo_record.log", "a+") as f:
        f.write("19. ppg\n")
|
|
|
|
|
@pytest.mark.algotest
def test_sqn():
    """Run ``serial_pipeline`` on copies of the CartPole SQN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "20. sqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_sqn_config)
    create_cfg = deepcopy(cartpole_sqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("20. sqn\n")
|
|
|
|
|
@pytest.mark.algotest
def test_qrdqn():
    """Run ``serial_pipeline`` on copies of the CartPole QRDQN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "21. qrdqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_qrdqn_config)
    create_cfg = deepcopy(cartpole_qrdqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("21. qrdqn\n")
|
|
|
|
|
@pytest.mark.algotest
def test_acer():
    """Run ``serial_pipeline`` on copies of the CartPole ACER config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "22. acer" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_acer_config)
    create_cfg = deepcopy(cartpole_acer_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("22. acer\n")
|
|
|
|
|
@pytest.mark.algotest
def test_selfplay():
    """Run the self-play demo entry point on a copy of the league demo PPO config.

    Any ``Exception`` fails the test with "pipeline fail". On success the line
    "23. selfplay" is appended to ``./algo_record.log``.
    """
    try:
        demo_cfg = deepcopy(league_demo_ppo_config)
        selfplay_main(demo_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the entry point did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("23. selfplay\n")
|
|
|
|
|
@pytest.mark.algotest
def test_league():
    """Run the league demo entry point on a copy of the league demo PPO config.

    Any ``Exception`` fails the test with "pipeline fail". On success the line
    "24. league" is appended to ``./algo_record.log``.
    """
    try:
        demo_cfg = deepcopy(league_demo_ppo_config)
        league_main(demo_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the entry point did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("24. league\n")
|
|
|
|
|
@pytest.mark.algotest
def test_sqil():
    """Train an SQL expert on CartPole, then run the SQIL pipeline against it.

    Steps:
      1. ``serial_pipeline`` is run on copies of the SQL config pair and the
         returned policy's ``collect_mode`` state dict is saved to
         ``./expert_policy.pth``. This step is not guarded, so an error here
         propagates as-is.
      2. ``serial_pipeline_sqil`` is run on copies of the SQIL config pair with
         ``policy.collect.model_path`` pointing at that file, plus a fresh
         copy of the SQL config pair as the expert config.

    An ``Exception`` in step 2 fails the test with "pipeline fail". On success
    the line "25. sqil" is appended to ``./algo_record.log``.
    """
    expert_policy_state_dict_path = './expert_policy.pth'
    config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    expert_policy = serial_pipeline(config, seed=0)
    torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)

    config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
    config[0].policy.collect.model_path = expert_policy_state_dict_path
    try:
        # Pass copies of the expert configs, consistent with every other
        # pipeline call in this file, rather than the imported module-level
        # objects themselves.
        expert_config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
        serial_pipeline_sqil(config, expert_config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the SQIL pipeline did not raise.
    with open("./algo_record.log", "a+") as f:
        f.write("25. sqil\n")
|
|
|
|
|
@pytest.mark.algotest
def test_cql():
    """Chain SAC training, demo-data collection and offline CQL on Pendulum.

    Each of the three guarded calls fails the test with "pipeline fail" if it
    raises an ``Exception``. On success the line "26. cql" is appended to
    ``./algo_record.log``.
    """
    import torch

    # Stage 1: run the serial pipeline on the SAC configs under exp_name 'sac'.
    sac_main_cfg = deepcopy(pendulum_sac_config)
    sac_create_cfg = deepcopy(pendulum_sac_create_config)
    sac_main_cfg.exp_name = 'sac'
    try:
        serial_pipeline([sac_main_cfg, sac_create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"

    # Stage 2: collect demonstration data with a checkpoint loaded from
    # ./sac/ckpt (presumably written by stage 1 -- the path matches the
    # exp_name set above). The load is not guarded by the try block.
    gen_cfg = [
        deepcopy(pendulum_sac_data_genearation_config),
        deepcopy(pendulum_sac_data_genearation_create_config),
    ]
    collect_settings = gen_cfg[0].policy.collect
    collect_count = collect_settings.n_sample
    expert_data_path = collect_settings.save_path
    state_dict = torch.load('./sac/ckpt/ckpt_best.pth.tar', map_location='cpu')
    try:
        collect_demo_data(
            gen_cfg,
            seed=0,
            collect_count=collect_count,
            expert_data_path=expert_data_path,
            state_dict=state_dict,
        )
    except Exception:
        assert False, "pipeline fail"

    # Stage 3: run the offline pipeline on the CQL configs.
    cql_cfg = [deepcopy(pendulum_cql_config), deepcopy(pendulum_cql_create_config)]
    try:
        serial_pipeline_offline(cql_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("26. cql\n")
|
|
|
|
|
@pytest.mark.algotest
def test_discrete_cql():
    """Chain QRDQN training, demo-data collection and offline discrete CQL on CartPole.

    Each of the three guarded calls fails the test with "pipeline fail" if it
    raises an ``Exception``. On success the line "27. discrete cql" is
    appended to ``./algo_record.log``.
    """
    import torch

    # Stage 1: run the serial pipeline on the QRDQN configs under exp_name
    # 'cartpole'.
    qrdqn_main_cfg = deepcopy(cartpole_qrdqn_config)
    qrdqn_create_cfg = deepcopy(cartpole_qrdqn_create_config)
    qrdqn_main_cfg.exp_name = 'cartpole'
    try:
        serial_pipeline([qrdqn_main_cfg, qrdqn_create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"

    # Stage 2: collect demonstration data with a checkpoint loaded from
    # cartpole/ckpt (presumably written by stage 1 -- the path matches the
    # exp_name set above). The load is not guarded by the try block.
    gen_cfg = [
        deepcopy(cartpole_qrdqn_generation_data_config),
        deepcopy(cartpole_qrdqn_generation_data_create_config),
    ]
    collect_count = gen_cfg[0].policy.collect.collect_count
    state_dict = torch.load('cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
    try:
        collect_demo_data(gen_cfg, seed=0, collect_count=collect_count, state_dict=state_dict)
    except Exception:
        assert False, "pipeline fail"

    # Stage 3: run the offline pipeline on the discrete CQL configs.
    cql_cfg = [deepcopy(cartpole_discrete_cql_config), deepcopy(cartpole_discrete_cql_create_config)]
    try:
        serial_pipeline_offline(cql_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("27. discrete cql\n")
|
|
|
|
|
|
|
def test_wqmix():
    """Run ``serial_pipeline`` on copies of the simple-spread WQMIX config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "28. wqmix" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_wqmix_config)
    create_cfg = deepcopy(ptz_simple_spread_wqmix_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("28. wqmix\n")
|
|
|
|
|
@pytest.mark.algotest
def test_mdqn():
    """Run ``serial_pipeline`` on copies of the CartPole MDQN config pair.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail".
    On success the line "29. mdqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_mdqn_config)
    create_cfg = deepcopy(cartpole_mdqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    # Only reached when the pipeline did not raise.
    with open("./algo_record.log", "a+") as record:
        record.write("29. mdqn\n")
|
|
|
|
|
|
|
def test_td3_bc():
    """Chain TD3 training, demo-data collection and offline TD3-BC on Pendulum.

    Each of the three guarded calls fails the test with "pipeline fail" if it
    raises an ``Exception``. On success the line "30. td3_bc" is appended to
    ``./algo_record.log``. The index is 30 because "29." is already used by
    ``test_mdqn`` in this file.
    """
    # Stage 1: run the serial pipeline on the TD3 configs under exp_name 'td3'.
    config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
    config[0].exp_name = 'td3'
    try:
        serial_pipeline(config, seed=0)
    except Exception:
        assert False, "pipeline fail"

    # Stage 2: collect demonstration data. The sample count, output path and
    # checkpoint path all come from the generation config. The checkpoint load
    # is not guarded by the try block.
    import torch
    config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
    collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
    expert_data_path = config[0].policy.collect.save_path
    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
    try:
        collect_demo_data(
            config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
        )
    except Exception:
        assert False, "pipeline fail"

    # Stage 3: run the offline pipeline on the TD3-BC configs.
    config = [deepcopy(pendulum_td3_bc_config), deepcopy(pendulum_td3_bc_create_config)]
    try:
        serial_pipeline_offline(config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as f:
        f.write("30. td3_bc\n")
|
|
|
|
|
|
|
def test_running_on_orchestrator():
    """Run the CartPole DQN DIJob on a freshly created cluster and orchestrator.

    Flow: create the cluster, create the orchestrator, apply the ``agconfig``
    and ``dijob-cartpole`` manifests from the ``config`` directory next to
    this file, wait up to 20 minutes for the DIJob to reach phase
    ``Succeeded``, then print the last 20 lines of the coordinator pod log.

    Teardown (DIJob, orchestrator, cluster) runs in ``finally`` blocks so that
    a failure or timeout part-way through does not leave those resources
    behind. Each resource is torn down only if its creation step was reached.

    Raises:
        TimeoutError: propagated from ``wait_for_dijob_condition``.
        RuntimeError: propagated from ``create_object_from_config``.
    """
    from kubernetes import config, client, dynamic
    from ding.utils import K8sLauncher, OrchestratorLauncher
    cluster_name = 'test-k8s-launcher'
    config_path = os.path.join(os.path.dirname(__file__), 'config', 'k8s-config.yaml')

    launcher = K8sLauncher(config_path)
    launcher.name = cluster_name
    launcher.create_cluster()
    try:
        olauncher = OrchestratorLauncher('v0.2.0-rc.0', cluster=launcher)
        olauncher.create_orchestrator()
        try:
            namespace = 'default'
            name = 'cartpole-dqn'
            timeout = 20 * 60
            file_path = os.path.dirname(__file__)
            agconfig_path = os.path.join(file_path, 'config', 'agconfig.yaml')
            dijob_path = os.path.join(file_path, 'config', 'dijob-cartpole.yaml')
            create_object_from_config(agconfig_path, 'di-system')
            create_object_from_config(dijob_path, namespace)

            # load_kube_config() installs the loaded settings as the client
            # library's default configuration and returns None, so a single
            # call followed by a default ApiClient() is sufficient.
            config.load_kube_config()
            dyclient = dynamic.DynamicClient(client.ApiClient())
            dijobapi = dyclient.resources.get(api_version='diengine.opendilab.org/v1alpha1', kind='DIJob')

            try:
                wait_for_dijob_condition(dijobapi, name, namespace, 'Succeeded', timeout)

                v1 = client.CoreV1Api()
                logs = v1.read_namespaced_pod_log(f'{name}-coordinator', namespace, tail_lines=20)
                print(f'\ncoordinator logs:\n {logs} \n')
            finally:
                # Remove the job whether or not it succeeded.
                dijobapi.delete(name=name, namespace=namespace, body={})
        finally:
            olauncher.delete_orchestrator()
    finally:
        launcher.delete_cluster()
|
|
|
|
|
def create_object_from_config(config_path: str, namespace: str = 'default'):
    """Apply a manifest file with ``kubectl apply`` in the given namespace.

    Args:
        config_path: Path passed to kubectl's ``-f`` option.
        namespace: Namespace passed to kubectl's ``-n`` option.

    Raises:
        RuntimeError: if kubectl wrote something to stderr that contains
            neither 'WARN' nor 'already exists'. The process exit status is
            not inspected.
    """
    completed = subprocess.run(
        ['kubectl', 'apply', '-n', namespace, '-f', config_path],
        stderr=subprocess.PIPE,
    )
    err_str = completed.stderr.decode('utf-8').strip()
    if not err_str:
        return
    # Warnings and "already exists" messages are tolerated.
    if 'WARN' in err_str or 'already exists' in err_str:
        return
    raise RuntimeError(f'Failed to create object: {err_str}')
|
|
|
|
|
def delete_object_from_config(config_path: str, namespace: str = 'default'):
    """Delete the objects in a manifest file with ``kubectl delete``.

    Args:
        config_path: Path passed to kubectl's ``-f`` option.
        namespace: Namespace passed to kubectl's ``-n`` option.

    Raises:
        RuntimeError: if kubectl wrote something to stderr that contains
            neither 'WARN' nor 'NotFound'. The process exit status is not
            inspected.
    """
    completed = subprocess.run(
        ['kubectl', 'delete', '-n', namespace, '-f', config_path],
        stderr=subprocess.PIPE,
    )
    err_str = completed.stderr.decode('utf-8').strip()
    if not err_str:
        return
    # Warnings and "NotFound" messages are tolerated.
    if 'WARN' in err_str or 'NotFound' in err_str:
        return
    raise RuntimeError(f'Failed to delete object: {err_str}')
|
|
|
|
|
def wait_for_dijob_condition(dijobapi, name: str, namespace: str, phase: str, timeout: int = 60, interval: int = 1):
    """Poll a DIJob until its status reports ``phase`` or the timeout elapses.

    Args:
        dijobapi: Object exposing ``get(name=..., namespace=...)`` that returns
            the job. The job's ``status`` may be ``None``; otherwise it has a
            ``phase`` attribute.
        name: Name of the DIJob.
        namespace: Namespace of the DIJob.
        phase: Phase value to wait for.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between polls.

    Returns:
        None, as soon as a poll observes ``status.phase == phase``.

    Raises:
        TimeoutError: if the phase was not observed within ``timeout``. This
            includes the case where the job never reported any status.
    """
    start = time.time()
    while True:
        dijob = dijobapi.get(name=name, namespace=namespace)
        status = dijob.status
        # A job without a status yet simply counts as "not there yet"; it
        # must not be dereferenced.
        if status is not None and status.phase == phase:
            return
        if time.time() - start >= timeout:
            break
        time.sleep(interval)
    raise TimeoutError(f'Timeout waiting for DIJob: {name} to be {phase}')
|
|