"""Smoke tests that run each algorithm's training entry point once.

Every test calls its entry point with ``seed=0``; an exception raised by the
entry point fails the test with the message "pipeline fail". On success the
test appends a numbered line to ``./algo_record.log``. That file is truncated
and given a header line when this module is imported.
"""
import pytest
import time
import os
import torch
import subprocess
from copy import deepcopy

from ding.entry import serial_pipeline, serial_pipeline_offline, collect_demo_data, serial_pipeline_onpolicy
from ding.entry.serial_entry_sqil import serial_pipeline_sqil
from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config  # noqa
from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main
from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main
from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config  # noqa
from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
from dizoo.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
from dizoo.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
from dizoo.league_demo.league_demo_ppo_main import main as league_main
from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import pendulum_sac_data_genearation_config, pendulum_sac_data_genearation_create_config  # noqa
from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config  # noqa
from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config  # noqa
from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config  # noqa
from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config  # noqa
from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config

# Path of the log that collects one line per successfully finished algorithm.
_RECORD_PATH = "./algo_record.log"

# Import-time side effect: start every run with a fresh record file.
with open(_RECORD_PATH, "w+") as f:
    f.write("ALGO TEST STARTS\n")


def _record(entry: str) -> None:
    """Append ``entry`` as one line to the algorithm record log."""
    with open(_RECORD_PATH, "a+") as f:
        f.write(entry + "\n")


def _run_or_fail(entry_fn, *args, **kwargs):
    """Call ``entry_fn(*args, **kwargs)`` and return its result.

    Any ``Exception`` is re-raised as ``AssertionError("pipeline fail")``,
    chained to the original error so the real traceback stays visible in the
    pytest report.
    """
    try:
        return entry_fn(*args, **kwargs)
    except Exception as err:
        raise AssertionError("pipeline fail") from err


@pytest.mark.algotest
def test_dqn():
    """Run ``serial_pipeline`` with the CartPole DQN configs."""
    config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("1. dqn")


@pytest.mark.algotest
def test_ddpg():
    """Run ``serial_pipeline`` with the Pendulum DDPG configs."""
    config = [deepcopy(pendulum_ddpg_config), deepcopy(pendulum_ddpg_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("2. ddpg")


@pytest.mark.algotest
def test_td3():
    """Run ``serial_pipeline`` with the Pendulum TD3 configs."""
    config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("3. td3")


@pytest.mark.algotest
def test_a2c():
    """Run ``serial_pipeline_onpolicy`` with the CartPole A2C configs."""
    config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
    _run_or_fail(serial_pipeline_onpolicy, config, seed=0)
    _record("4. a2c")


@pytest.mark.algotest
def test_rainbow():
    """Run ``serial_pipeline`` with the CartPole Rainbow configs."""
    config = [deepcopy(cartpole_rainbow_config), deepcopy(cartpole_rainbow_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("5. rainbow")


@pytest.mark.algotest
def test_ppo():
    """Run the CartPole PPO ``main`` entry with a copy of the PPO config."""
    # ppo_main only takes the main config; the create config is not used here.
    _run_or_fail(ppo_main, deepcopy(cartpole_ppo_config), seed=0)
    _record("6. ppo")


# @pytest.mark.algotest
def test_collaq():
    """Run ``serial_pipeline`` with the simple-spread CollaQ configs."""
    config = [deepcopy(ptz_simple_spread_collaq_config), deepcopy(ptz_simple_spread_collaq_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("7. collaq")


# @pytest.mark.algotest
def test_coma():
    """Run ``serial_pipeline`` with the simple-spread COMA configs."""
    config = [deepcopy(ptz_simple_spread_coma_config), deepcopy(ptz_simple_spread_coma_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("8. coma")


@pytest.mark.algotest
def test_sac():
    """Run ``serial_pipeline`` with the Pendulum SAC configs."""
    config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("9. sac")


@pytest.mark.algotest
def test_c51():
    """Run ``serial_pipeline`` with the CartPole C51 configs."""
    config = [deepcopy(cartpole_c51_config), deepcopy(cartpole_c51_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("10. c51")


@pytest.mark.algotest
def test_r2d2():
    """Run ``serial_pipeline`` with the CartPole R2D2 configs."""
    config = [deepcopy(cartpole_r2d2_config), deepcopy(cartpole_r2d2_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("11. r2d2")


@pytest.mark.algotest
def test_pg():
    """Run ``serial_pipeline_onpolicy`` with the CartPole PG configs."""
    config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
    _run_or_fail(serial_pipeline_onpolicy, config, seed=0)
    _record("12. pg")


# @pytest.mark.algotest
def test_atoc():
    """Run ``serial_pipeline`` with the simple-spread ATOC configs."""
    config = [deepcopy(ptz_simple_spread_atoc_config), deepcopy(ptz_simple_spread_atoc_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("13. atoc")


# @pytest.mark.algotest
def test_vdn():
    """Run ``serial_pipeline`` with the simple-spread VDN configs."""
    config = [deepcopy(ptz_simple_spread_vdn_config), deepcopy(ptz_simple_spread_vdn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("14. vdn")


# @pytest.mark.algotest
def test_qmix():
    """Run ``serial_pipeline`` with the simple-spread QMIX configs."""
    config = [deepcopy(ptz_simple_spread_qmix_config), deepcopy(ptz_simple_spread_qmix_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("15. qmix")


@pytest.mark.algotest
def test_impala():
    """Run ``serial_pipeline`` with the CartPole IMPALA configs."""
    config = [deepcopy(cartpole_impala_config), deepcopy(cartpole_impala_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("16. impala")


@pytest.mark.algotest
def test_iqn():
    """Run ``serial_pipeline`` with the CartPole IQN configs."""
    config = [deepcopy(cartpole_iqn_config), deepcopy(cartpole_iqn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("17. iqn")


@pytest.mark.algotest
def test_her_dqn():
    """Run the bitflip DQN ``main`` entry on a 5-bit variant of the HER DQN config."""
    # Work on a copy so the overrides below do not leak into the imported
    # config object shared with any other user in the same process.
    her_config = deepcopy(bitflip_her_dqn_config)
    her_config.exp_name = 'bitflip5_dqn'
    her_config.env.n_bits = 5
    her_config.policy.model.obs_shape = 10
    her_config.policy.model.action_shape = 5
    _run_or_fail(bitflip_dqn_main, her_config, seed=0)
    _record("18. her dqn")


@pytest.mark.algotest
def test_ppg():
    """Run the CartPole PPG ``main`` entry with a copy of the PPG config."""
    _run_or_fail(ppg_main, deepcopy(cartpole_ppg_config), seed=0)
    _record("19. ppg")


@pytest.mark.algotest
def test_sqn():
    """Run ``serial_pipeline`` with the CartPole SQN configs."""
    config = [deepcopy(cartpole_sqn_config), deepcopy(cartpole_sqn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("20. sqn")


@pytest.mark.algotest
def test_qrdqn():
    """Run ``serial_pipeline`` with the CartPole QRDQN configs."""
    config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("21. qrdqn")


@pytest.mark.algotest
def test_acer():
    """Run ``serial_pipeline`` with the CartPole ACER configs."""
    config = [deepcopy(cartpole_acer_config), deepcopy(cartpole_acer_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("22. acer")


@pytest.mark.algotest
def test_selfplay():
    """Run the self-play demo ``main`` entry with a copy of the league demo PPO config."""
    _run_or_fail(selfplay_main, deepcopy(league_demo_ppo_config), seed=0)
    _record("23. selfplay")


@pytest.mark.algotest
def test_league():
    """Run the league demo ``main`` entry with a copy of the league demo PPO config."""
    _run_or_fail(league_main, deepcopy(league_demo_ppo_config), seed=0)
    _record("24. league")


@pytest.mark.algotest
def test_sqil():
    """Train an SQL expert, save its collect-mode state dict, then run SQIL against it."""
    expert_policy_state_dict_path = './expert_policy.pth'

    # Train the expert. This call is deliberately outside the "pipeline fail"
    # wrapper, as in the original: a failure here surfaces as a raw error.
    config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    expert_policy = serial_pipeline(config, seed=0)
    torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)

    # Train SQIL, pointing its collector at the saved expert weights.
    config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
    config[0].policy.collect.model_path = expert_policy_state_dict_path
    expert_config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    _run_or_fail(serial_pipeline_sqil, config, expert_config, seed=0)
    _record("25. sqil")


@pytest.mark.algotest
def test_cql():
    """Train a SAC expert on Pendulum, collect demo data from it, then train CQL offline."""
    # train expert
    config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
    config[0].exp_name = 'sac'
    _run_or_fail(serial_pipeline, config, seed=0)

    # collect expert data
    config = [deepcopy(pendulum_sac_data_genearation_config), deepcopy(pendulum_sac_data_genearation_create_config)]
    collect_count = config[0].policy.collect.n_sample
    expert_data_path = config[0].policy.collect.save_path
    # Checkpoint path matches the exp_name 'sac' set for the expert run above.
    state_dict = torch.load('./sac/ckpt/ckpt_best.pth.tar', map_location='cpu')
    _run_or_fail(
        collect_demo_data,
        config,
        seed=0,
        collect_count=collect_count,
        expert_data_path=expert_data_path,
        state_dict=state_dict
    )

    # train cql
    config = [deepcopy(pendulum_cql_config), deepcopy(pendulum_cql_create_config)]
    _run_or_fail(serial_pipeline_offline, config, seed=0)
    _record("26. cql")


@pytest.mark.algotest
def test_discrete_cql():
    """Train a QRDQN expert on CartPole, collect demo data from it, then train discrete CQL offline."""
    # train expert
    config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
    config[0].exp_name = 'cartpole'
    _run_or_fail(serial_pipeline, config, seed=0)

    # collect expert data
    config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
    collect_count = config[0].policy.collect.collect_count
    # Checkpoint path matches the exp_name 'cartpole' set for the expert run above.
    state_dict = torch.load('cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
    _run_or_fail(collect_demo_data, config, seed=0, collect_count=collect_count, state_dict=state_dict)

    # train cql
    config = [deepcopy(cartpole_discrete_cql_config), deepcopy(cartpole_discrete_cql_create_config)]
    _run_or_fail(serial_pipeline_offline, config, seed=0)
    _record("27. discrete cql")


# @pytest.mark.algotest
def test_wqmix():
    """Run ``serial_pipeline`` with the simple-spread WQMIX configs."""
    config = [deepcopy(ptz_simple_spread_wqmix_config), deepcopy(ptz_simple_spread_wqmix_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("28. wqmix")


@pytest.mark.algotest
def test_mdqn():
    """Run ``serial_pipeline`` with the CartPole MDQN configs."""
    config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
    _run_or_fail(serial_pipeline, config, seed=0)
    _record("29. mdqn")


# @pytest.mark.algotest
def test_td3_bc():
    """Train a TD3 expert on Pendulum, collect demo data from it, then train TD3-BC offline."""
    # train expert
    config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
    config[0].exp_name = 'td3'
    _run_or_fail(serial_pipeline, config, seed=0)

    # collect expert data
    config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
    collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
    expert_data_path = config[0].policy.collect.save_path
    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
    _run_or_fail(
        collect_demo_data,
        config,
        seed=0,
        collect_count=collect_count,
        expert_data_path=expert_data_path,
        state_dict=state_dict
    )

    # train td3 bc
    config = [deepcopy(pendulum_td3_bc_config), deepcopy(pendulum_td3_bc_create_config)]
    _run_or_fail(serial_pipeline_offline, config, seed=0)
    # Numbered 30: the previous "29." collided with the mdqn record.
    _record("30. td3_bc")


# @pytest.mark.algotest
def test_running_on_orchestrator():
    """Create a k8s cluster and orchestrator, run the CartPole DQN DIJob, then tear everything down.

    NOTE(review): teardown is not wrapped in ``finally``, so a failure part-way
    through leaves the cluster and orchestrator in place — confirm whether that
    is intentional (e.g. for debugging) before changing it.
    """
    # Imported lazily so the other tests do not need the kubernetes package.
    from kubernetes import config, client, dynamic
    from ding.utils import K8sLauncher, OrchestratorLauncher
    cluster_name = 'test-k8s-launcher'
    config_path = os.path.join(os.path.dirname(__file__), 'config', 'k8s-config.yaml')
    # create cluster
    launcher = K8sLauncher(config_path)
    launcher.name = cluster_name
    launcher.create_cluster()

    # create orchestrator
    olauncher = OrchestratorLauncher('v0.2.0-rc.0', cluster=launcher)
    olauncher.create_orchestrator()

    # create dijob
    namespace = 'default'
    name = 'cartpole-dqn'
    timeout = 20 * 60
    file_path = os.path.dirname(__file__)
    agconfig_path = os.path.join(file_path, 'config', 'agconfig.yaml')
    dijob_path = os.path.join(file_path, 'config', 'dijob-cartpole.yaml')
    create_object_from_config(agconfig_path, 'di-system')
    create_object_from_config(dijob_path, namespace)

    # watch for dijob to converge
    config.load_kube_config()
    # NOTE(review): the return value of load_kube_config() is passed as the
    # ApiClient configuration — confirm this is what the client expects.
    dyclient = dynamic.DynamicClient(client.ApiClient(configuration=config.load_kube_config()))
    dijobapi = dyclient.resources.get(api_version='diengine.opendilab.org/v1alpha1', kind='DIJob')

    wait_for_dijob_condition(dijobapi, name, namespace, 'Succeeded', timeout)

    v1 = client.CoreV1Api()
    logs = v1.read_namespaced_pod_log(f'{name}-coordinator', namespace, tail_lines=20)
    print(f'\ncoordinator logs:\n {logs} \n')

    # delete dijob
    dijobapi.delete(name=name, namespace=namespace, body={})
    # delete orchestrator
    olauncher.delete_orchestrator()
    # delete k8s cluster
    launcher.delete_cluster()


def _run_kubectl(verb: str, config_path: str, namespace: str) -> str:
    """Run ``kubectl <verb> -n <namespace> -f <config_path>`` and return its stripped stderr text."""
    args = ['kubectl', verb, '-n', namespace, '-f', config_path]
    completed = subprocess.run(args, stderr=subprocess.PIPE)
    return completed.stderr.decode('utf-8').strip()


def create_object_from_config(config_path: str, namespace: str = 'default'):
    """Apply the manifest at ``config_path`` in ``namespace`` via ``kubectl apply``.

    Raises:
        RuntimeError: if kubectl writes anything to stderr that contains
            neither 'WARN' nor 'already exists'.
    """
    err_str = _run_kubectl('apply', config_path, namespace)
    if err_str != '' and 'WARN' not in err_str and 'already exists' not in err_str:
        raise RuntimeError(f'Failed to create object: {err_str}')


def delete_object_from_config(config_path: str, namespace: str = 'default'):
    """Delete the objects of the manifest at ``config_path`` in ``namespace`` via ``kubectl delete``.

    Raises:
        RuntimeError: if kubectl writes anything to stderr that contains
            neither 'WARN' nor 'NotFound'.
    """
    err_str = _run_kubectl('delete', config_path, namespace)
    if err_str != '' and 'WARN' not in err_str and 'NotFound' not in err_str:
        raise RuntimeError(f'Failed to delete object: {err_str}')


def wait_for_dijob_condition(dijobapi, name: str, namespace: str, phase: str, timeout: int = 60, interval: int = 1):
    """Poll a DIJob until its ``status.phase`` equals ``phase`` or ``timeout`` seconds pass.

    Args:
        dijobapi: object whose ``get(name=..., namespace=...)`` returns the job.
        name: name of the DIJob to fetch.
        namespace: namespace of the DIJob.
        phase: phase value to wait for.
        timeout: maximum number of seconds to keep polling.
        interval: seconds to sleep between polls.

    Raises:
        TimeoutError: if the job has not reached ``phase`` when polling stops.
    """
    start = time.time()
    dijob = dijobapi.get(name=name, namespace=namespace)
    while (dijob.status is None or dijob.status.phase != phase) and time.time() - start < timeout:
        time.sleep(interval)
        dijob = dijobapi.get(name=name, namespace=namespace)

    # status may still be None when the timeout expires; without this guard
    # the phase lookup raises AttributeError instead of the TimeoutError.
    if dijob.status is not None and dijob.status.phase == phase:
        return
    raise TimeoutError(f'Timeout waiting for DIJob: {name} to be {phase}')