from typing import Union, Optional, List, Any, Callable, Tuple
import pickle
import torch
from functools import partial
from ding.config import compile_config, read_config
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
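

# Standalone evaluation entry: load a trained DI-engine policy from a checkpoint and replay a single episode in the environment.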
def eval(
    input_cfg: Union[str, Tuple[dict, dict]],
    seed: int = 0,
    env_setting: Optional[List[Any]] = None,
    model: Optional[torch.nn.Module] = None,
    state_dict: Optional[dict] = None,
) -> float:
r"""
Overview:
Pure evaluation entry.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
"""
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    create_cfg.policy.type += '_command'
    cfg = compile_config(cfg, auto=True, create_cfg=create_cfg)
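    # Build a single evaluator environment from the compiled env config and seed it, together with the relevant Python packages.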
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    env = env_fn(evaluator_env_cfg[0])
    env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
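    # Create the policy in eval mode and load the trained weights, either from the ``state_dict`` argument or from ``cfg.learner.load_path``.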
    policy = create_policy(cfg.policy, model=model, enable_field=['eval']).eval_mode
    if state_dict is None:
        state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
    policy.load_state_dict(state_dict)
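    # Roll out one episode: the policy expects a dict of observations keyed by env id (a single env with id 0 here) and returns per-env outputs.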
    obs = env.reset()
    episode_return = 0.
    while True:
        policy_output = policy.forward({0: obs})
        action = policy_output[0]['action']
        print(action)
        timestep = env.step(action)
        episode_return += timestep.reward
        obs = timestep.obs
        if timestep.done:
            print(timestep.info)
            break
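    # Save a replay of the finished episode (named after the SMAC map) and report the episode return.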
    env.save_replay(replay_dir='.', prefix=env._map_name)
    print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
    return episode_return


if __name__ == "__main__":
    path = '../exp/MMM/qmix/1/ckpt_BaseLearner_Wed_Jul_14_22_16_56_2021/iteration_9900.pth.tar'
    cfg = '../config/smac_MMM_qmix_config.py'
    state_dict = torch.load(path, map_location='cpu')
    eval(cfg, seed=0, state_dict=state_dict)