from typing import Optional

import ding.config
from .a2c import A2CAgent
from .c51 import C51Agent
from .ddpg import DDPGAgent
from .dqn import DQNAgent
from .pg import PGAgent
from .ppof import PPOF
from .ppo_offpolicy import PPOOffPolicyAgent
from .sac import SACAgent
from .sql import SQLAgent
from .td3 import TD3Agent

# Registry of the agents exposed by this package: canonical algorithm name -> agent class.
# Each key is also the attribute name looked up on ``ding.config.example``.
supported_algo = dict(
    A2C=A2CAgent,
    C51=C51Agent,
    DDPG=DDPGAgent,
    DQN=DQNAgent,
    PG=PGAgent,
    PPOF=PPOF,
    PPOOffPolicy=PPOOffPolicyAgent,
    SAC=SACAgent,
    SQL=SQLAgent,
    TD3=TD3Agent,
)

supported_algo_list = list(supported_algo.keys())

# Upper-cased name -> canonical name. Needed because "PPOOffPolicy" is mixed-case, so comparing
# ``algo.upper()`` against ``supported_algo_list`` directly can never match it.
_ALGO_BY_UPPER = {name.upper(): name for name in supported_algo_list}


def _env_ids_of(algo_name: str) -> list:
    """Return the keys of ``ding.config.example.<algo_name>.supported_env`` as a list."""
    return list(getattr(ding.config.example, algo_name).supported_env.keys())


def env_supported(algo: Optional[str] = None) -> list:
    """
    Return the list of envs supported by di-engine.

    Arguments:
        - algo: Algorithm name, matched case-insensitively. If ``None``, the envs of all \
            algorithms are merged.
    Returns:
        - The env ids of ``algo``, or the de-duplicated union over every algorithm when ``algo`` \
            is ``None`` (the union comes from a set, so its order is not meaningful).
    Raises:
        - ValueError: If ``algo`` is given but is not one of the supported algorithms.
    """
    if algo is not None:
        canonical_name = _ALGO_BY_UPPER.get(algo.upper())
        if canonical_name is None:
            raise ValueError("The algo {} is not supported by di-engine.".format(algo))
        return _env_ids_of(canonical_name)

    all_envs = set()
    for name in supported_algo_list:
        all_envs.update(_env_ids_of(name))
    # return the list of the envs
    return list(all_envs)


# Snapshot of every supported env id, computed once at import time. This is a list, not a dict.
supported_env = env_supported()


def algo_supported(env_id: Optional[str] = None) -> list:
    """
    Return the list of algos supported by di-engine.

    Arguments:
        - env_id: Env id, matched case-insensitively. If ``None``, all algorithms are returned.
    Returns:
        - Canonical names of the algorithms whose ``supported_env`` contains ``env_id``, in \
            registry order, or all algorithm names when ``env_id`` is ``None``.
    Raises:
        - ValueError: If ``env_id`` is given but no algorithm supports it.
    """
    if env_id is None:
        # Hand out a copy so callers cannot mutate the module-level registry list.
        return list(supported_algo_list)

    env_key = env_id.upper()
    matched = [
        name for name in supported_algo_list
        if any(env_key == item.upper() for item in _env_ids_of(name))
    ]
    if not matched:
        raise ValueError("The env {} is not supported by di-engine.".format(env_id))
    return matched


def is_supported(env_id: Optional[str] = None, algo: Optional[str] = None) -> bool:
    """
    Check if the env-algo pair is supported by di-engine.

    Both names are matched case-insensitively.

    Arguments:
        - env_id: Env id to check, or ``None`` to check the algorithm alone.
        - algo: Algorithm name to check, or ``None`` to check the env alone.
    Returns:
        - ``True`` if the given env, the given algorithm, or the given pair is supported; \
            ``False`` otherwise.
    Raises:
        - ValueError: If both ``env_id`` and ``algo`` are ``None``.
    """
    if env_id is None and algo is None:
        raise ValueError("Please specify the env or algo.")

    # An unknown algorithm can never be supported, with or without an env.
    if algo is not None and algo.upper() not in _ALGO_BY_UPPER:
        return False
    if env_id is None:
        return True

    # ``supported_env`` is a list (not a dict) and ``env_supported`` returns original-case ids,
    # so normalise the candidates' case before the membership test.
    candidates = supported_env if algo is None else env_supported(algo)
    env_key = env_id.upper()
    return any(env_key == item.upper() for item in candidates)