File size: 1,405 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gym
import torch
from easydict import EasyDict

from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.model import DQN
from ding.policy import DQNPolicy, single_env_forward_wrapper
from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config
from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv


def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
    """Roll out one evaluation episode of a trained DQN agent on CliffWalking.

    Builds a ``DQN`` model from the compiled config and loads the ``'model'``
    entry of the checkpoint at ``ckpt_path`` into it. It then runs the policy in
    eval mode on a ``CliffWalkingEnv`` (with replay saving enabled) until the
    episode ends, and prints the accumulated episode return. The environment
    is always closed, even if loading or the rollout raises.

    Args:
        main_config: Experiment config. NOTE: its ``exp_name`` is overwritten
            in place with ``'cliffwalking_dqn_seed0_deploy'``.
        create_config: Creation config forwarded to ``compile_config``.
        ckpt_path: Path to a checkpoint whose model weights are stored under
            the ``'model'`` key.
    """
    main_config.exp_name = 'cliffwalking_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = CliffWalkingEnv(cfg.env)
    try:
        env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
        model = DQN(**cfg.policy.model)
        # Load onto CPU so deployment does not require the device used in training.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        policy = DQNPolicy(cfg.policy, model=model).eval_mode
        # Presumably adapts the batched policy forward to a single env's obs/action.
        forward_fn = single_env_forward_wrapper(policy.forward)
        obs = env.reset()
        returns = 0.
        while True:
            action = forward_fn(obs)
            obs, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
        print(f'Deploy is finished, final episode return is: {returns}')
    finally:
        # Release the env (and any replay/video resources it holds) on every path.
        env.close()


if __name__ == "__main__":
    # Deploy the checkpoint written by a training run of this config; assumes that
    # run used exp_name 'cliffwalking_dqn_seed0' -- adjust the path otherwise.
    main(
        main_config=main_config,
        create_config=create_config,
        ckpt_path='./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar',
    )