import gym import torch from easydict import EasyDict from ding.config import compile_config from ding.envs import DingEnvWrapper from ding.model import DQN from ding.policy import DQNPolicy, single_env_forward_wrapper from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str): main_config.exp_name = f'cliffwalking_dqn_seed0_deploy' cfg = compile_config(main_config, create_cfg=create_config, auto=True) env = CliffWalkingEnv(cfg.env) env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video') model = DQN(**cfg.policy.model) state_dict = torch.load(ckpt_path, map_location='cpu') model.load_state_dict(state_dict['model']) policy = DQNPolicy(cfg.policy, model=model).eval_mode forward_fn = single_env_forward_wrapper(policy.forward) obs = env.reset() returns = 0. while True: action = forward_fn(obs) obs, rew, done, info = env.step(action) returns += rew if done: break print(f'Deploy is finished, final epsiode return is: {returns}') if __name__ == "__main__": main( main_config=main_config, create_config=create_config, ckpt_path=f'./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar' )