|
import os |
|
from functools import partial |
|
from typing import Optional, Union, List |
|
|
|
import numpy as np |
|
import torch |
|
from ding.bonus.common import TrainingReturn, EvalReturn |
|
from ding.config import save_config_py, compile_config |
|
from ding.envs import create_env_manager |
|
from ding.envs import get_vec_env_setting |
|
from ding.policy import create_policy |
|
from ding.rl_utils import get_epsilon_greedy_fn |
|
from ding.utils import set_pkg_seed, get_rank |
|
from ding.worker import BaseLearner |
|
from ditk import logging |
|
from easydict import EasyDict |
|
from tensorboardX import SummaryWriter |
|
|
|
from lzero.agent.config.muzero import supported_env_cfg |
|
from lzero.entry.utils import log_buffer_memory_usage, random_collect |
|
from lzero.mcts import MuZeroGameBuffer |
|
from lzero.policy import visit_count_temperature |
|
from lzero.policy.muzero import MuZeroPolicy |
|
from lzero.policy.random_policy import LightZeroRandomPolicy |
|
from lzero.worker import MuZeroCollector as Collector |
|
from lzero.worker import MuZeroEvaluator as Evaluator |
|
|
|
|
|
class MuZeroAgent: |
|
""" |
|
Overview: |
|
        Agent class for executing the MuZero algorithm, providing methods for training, deployment, and batch evaluation.
|
Interfaces: |
|
__init__, train, deploy, batch_evaluate |
|
Properties: |
|
best |
|
|
|
.. note:: |
|
This agent class is tailored for use with the HuggingFace Model Zoo for LightZero |
|
(e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-MuZero), |
|
and provides methods such as "train" and "deploy". |
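
    Examples:
        A minimal usage sketch, assuming ``CartPole-v0`` is among ``supported_env_list`` (the step count is illustrative):

        >>> agent = MuZeroAgent(env_id='CartPole-v0')
        >>> agent.train(step=int(1e5))
        >>> agent.deploy(enable_save_replay=True)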
|
""" |
|
|
|
supported_env_list = list(supported_env_cfg.keys()) |
|
|
|
def __init__( |
|
self, |
|
            env_id: Optional[str] = None,
|
seed: int = 0, |
|
            exp_name: Optional[str] = None,
|
model: Optional[torch.nn.Module] = None, |
|
cfg: Optional[Union[EasyDict, dict]] = None, |
|
            policy_state_dict: Optional[str] = None,
|
) -> None: |
|
""" |
|
Overview: |
|
Initialize the MuZeroAgent instance with environment parameters, model, and configuration. |
|
Arguments: |
|
- env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. |
|
- seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. |
|
- exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. |
|
- model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. |
|
- cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. |
|
- policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None. |
|
|
|
.. note:: |
|
- If `env_id` is not specified, it must be included in `cfg`. |
|
- The `supported_env_list` contains all the environment IDs that are supported by this agent. |
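
        Examples:
            A minimal sketch of the two construction paths; ``'CartPole-v0'`` as an id/key is an assumption based on the HuggingFace example above:

            >>> # Construct from a supported env id with its bundled default config.
            >>> agent = MuZeroAgent(env_id='CartPole-v0', seed=0)
            >>> # Or construct from an explicit config that carries the env id itself.
            >>> from lzero.agent.config.muzero import supported_env_cfg
            >>> agent = MuZeroAgent(cfg=supported_env_cfg['CartPole-v0'], seed=0)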
|
""" |
|
        assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
|
|
if cfg is not None and not isinstance(cfg, EasyDict): |
|
cfg = EasyDict(cfg) |
|
|
|
if env_id is not None: |
|
assert env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format( |
|
MuZeroAgent.supported_env_list |
|
) |
|
if cfg is None: |
|
cfg = supported_env_cfg[env_id] |
|
else: |
|
assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." |
|
else: |
|
assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." |
|
assert cfg.main_config.env.env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format( |
|
MuZeroAgent.supported_env_list |
|
) |
|
default_policy_config = EasyDict({"policy": MuZeroPolicy.default_config()}) |
|
default_policy_config.policy.update(cfg.main_config.policy) |
|
cfg.main_config.policy = default_policy_config.policy |
|
|
|
if exp_name is not None: |
|
cfg.main_config.exp_name = exp_name |
|
self.origin_cfg = cfg |
|
self.cfg = compile_config( |
|
cfg.main_config, seed=seed, env=None, auto=True, policy=MuZeroPolicy, create_cfg=cfg.create_config |
|
) |
|
self.exp_name = self.cfg.exp_name |
|
|
|
logging.getLogger().setLevel(logging.INFO) |
|
self.seed = seed |
|
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) |
|
if not os.path.exists(self.exp_name): |
|
os.makedirs(self.exp_name) |
|
save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) |
|
if model is None: |
|
if self.cfg.policy.model.model_type == 'mlp': |
|
from lzero.model.muzero_model_mlp import MuZeroModelMLP |
|
model = MuZeroModelMLP(**self.cfg.policy.model) |
|
elif self.cfg.policy.model.model_type == 'conv': |
|
from lzero.model.muzero_model import MuZeroModel |
|
model = MuZeroModel(**self.cfg.policy.model) |
|
else: |
|
raise NotImplementedError |
|
if self.cfg.policy.cuda and torch.cuda.is_available(): |
|
self.cfg.policy.device = 'cuda' |
|
else: |
|
self.cfg.policy.device = 'cpu' |
|
self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) |
|
if policy_state_dict is not None: |
|
            self.policy.learn_mode.load_state_dict(torch.load(policy_state_dict, map_location='cpu'))
|
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") |
|
|
|
self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) |
|
|
|
def train( |
|
self, |
|
step: int = int(1e7), |
|
) -> TrainingReturn: |
|
""" |
|
Overview: |
|
Train the agent through interactions with the environment. |
|
Arguments: |
|
- step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). |
|
Returns: |
|
- A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. |
|
.. note:: |
|
The method involves interacting with the environment, collecting experience, and optimizing the model. |
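
        Examples:
            A minimal sketch (the step count is illustrative):

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> result = agent.train(step=int(1e5))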
|
""" |
|
|
|
collector_env = create_env_manager( |
|
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] |
|
) |
|
evaluator_env = create_env_manager( |
|
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] |
|
) |
|
|
|
collector_env.seed(self.cfg.seed) |
|
evaluator_env.seed(self.cfg.seed, dynamic_seed=False) |
|
set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) |
|
|
|
|
|
        tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial')) if get_rank() == 0 else None
|
learner = BaseLearner( |
|
self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name |
|
        )

        # ==============================================================
        # MCTS+RL algorithms related core code
        # ==============================================================
policy_config = self.cfg.policy |
|
batch_size = policy_config.batch_size |
|
|
|
replay_buffer = MuZeroGameBuffer(policy_config) |
|
collector = Collector( |
|
env=collector_env, |
|
policy=self.policy.collect_mode, |
|
tb_logger=tb_logger, |
|
exp_name=self.cfg.exp_name, |
|
policy_config=policy_config |
|
) |
|
evaluator = Evaluator( |
|
eval_freq=self.cfg.policy.eval_freq, |
|
n_evaluator_episode=self.cfg.env.n_evaluator_episode, |
|
stop_value=self.cfg.env.stop_value, |
|
env=evaluator_env, |
|
policy=self.policy.eval_mode, |
|
tb_logger=tb_logger, |
|
exp_name=self.cfg.exp_name, |
|
policy_config=policy_config |
|
        )

        # ==============================================================
        # Main loop
        # ==============================================================
        # Learner's before_run hook.
learner.call_hook('before_run') |
|
|
|
if self.cfg.policy.update_per_collect is not None: |
|
            update_per_collect = self.cfg.policy.update_per_collect

        # Collect random data before training begins to aid early exploration.
if self.cfg.policy.random_collect_episode_num > 0: |
|
random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) |
|
|
|
while True: |
|
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) |
|
            collect_kwargs = {}
            # Set the temperature for visit-count distributions according to the current train_iter;
            # see Appendix D of the MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature( |
|
policy_config.manual_temperature_decay, |
|
policy_config.fixed_temperature_value, |
|
policy_config.threshold_training_steps_for_final_temperature, |
|
trained_steps=learner.train_iter |
|
) |
|
|
|
if policy_config.eps.eps_greedy_exploration_in_collect: |
|
epsilon_greedy_fn = get_epsilon_greedy_fn( |
|
start=policy_config.eps.start, |
|
end=policy_config.eps.end, |
|
decay=policy_config.eps.decay, |
|
type_=policy_config.eps.type |
|
) |
|
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) |
|
else: |
|
                collect_kwargs['epsilon'] = 0.0

            # Evaluate policy performance.
if evaluator.should_eval(learner.train_iter): |
|
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) |
|
if stop: |
|
                    break

            # Collect new data with the current policy.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) |
|
            if self.cfg.policy.update_per_collect is None:
                # Derive update_per_collect from the number of collected transitions and model_update_ratio.
                collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
                update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)

            replay_buffer.push_game_segments(new_data)
            # Remove the oldest data if the replay buffer is full.
            replay_buffer.remove_oldest_data_to_fit()

            # Learn policy from collected data.
            for i in range(update_per_collect):
                # The learner trains ``update_per_collect`` times per collection.
if replay_buffer.get_num_of_transitions() > batch_size: |
|
train_data = replay_buffer.sample(batch_size, self.policy) |
|
else: |
|
                    logging.warning(
                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
                        f'batch_size: {batch_size}, '
                        f'replay_buffer transitions: {replay_buffer.get_num_of_transitions()}, '
                        f'continue to collect now ....'
                    )
|
                    break

                # The core train step for MCTS+RL algorithms.
                log_vars = learner.train(train_data, collector.envstep)
|
|
|
if self.cfg.policy.use_priority: |
|
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) |
|
|
|
if collector.envstep >= step: |
|
break |
|
|
|
|
|
        # Learner's after_run hook.
        learner.call_hook('after_run')
|
|
|
return TrainingReturn(wandb_url=None) |
|
|
|
def deploy( |
|
self, |
|
enable_save_replay: bool = False, |
|
concatenate_all_replay: bool = False, |
|
replay_save_path: str = None, |
|
seed: Optional[Union[int, List]] = None, |
|
debug: bool = False |
|
) -> EvalReturn: |
|
""" |
|
Overview: |
|
            Deploy the agent for evaluation in the environment, with optional replay saving. The mean and
            standard deviation of the episode return are reported.
            If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`.
|
Arguments: |
|
- enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. |
|
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. |
|
- replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. |
|
- seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. |
|
            - debug (:obj:`bool`): Whether to enable debug mode. Defaults to False.
|
Returns: |
|
- An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. |
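
        Examples:
            A minimal sketch, given a constructed ``agent`` (seeds are illustrative):

            >>> result = agent.deploy(enable_save_replay=True, concatenate_all_replay=True, seed=[0, 1, 2])
            >>> print(result.eval_value, result.eval_value_std)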
|
""" |
|
|
|
        deploy_configs = [self.evaluator_env_cfg[0]]
|
|
|
        if isinstance(seed, int):
|
seed_list = [seed] |
|
elif seed: |
|
seed_list = seed |
|
else: |
|
seed_list = [0] |
|
|
|
reward_list = [] |
|
|
|
if enable_save_replay: |
|
replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( |
|
self.exp_name, 'videos' |
|
) |
|
            deploy_configs[0]['replay_path'] = replay_save_path
|
|
|
for seed in seed_list: |
|
|
|
            evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deploy_configs[0])])
|
|
|
evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) |
|
set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) |
|
|
|
|
|
|
|
|
|
policy_config = self.cfg.policy |
|
|
|
evaluator = Evaluator( |
|
eval_freq=self.cfg.policy.eval_freq, |
|
n_evaluator_episode=1, |
|
stop_value=self.cfg.env.stop_value, |
|
env=evaluator_env, |
|
policy=self.policy.eval_mode, |
|
exp_name=self.cfg.exp_name, |
|
policy_config=policy_config |
|
) |
|
|
|
|
|
|
|
|
|
|
|
            # Run a single evaluation episode for this seed.
            stop, reward = evaluator.eval()
|
reward_list.extend(reward['eval_episode_return']) |
|
|
|
if enable_save_replay: |
|
files = os.listdir(replay_save_path) |
|
files = [file for file in files if file.endswith('0.mp4')] |
|
files.sort() |
|
if concatenate_all_replay: |
|
|
|
with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: |
|
for file in files: |
|
f.write("file '{}'\n".format(file)) |
|
|
|
|
|
os.system( |
|
'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( |
|
os.path.join(replay_save_path, 'files.txt'), replay_save_path |
|
) |
|
) |
|
|
|
return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) |
|
|
|
def batch_evaluate( |
|
self, |
|
            n_evaluator_episode: Optional[int] = None,
|
) -> EvalReturn: |
|
""" |
|
Overview: |
|
Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. |
|
Arguments: |
|
- n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. |
|
If None, uses default value from configuration. Defaults to None. |
|
Returns: |
|
- An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. |
|
|
|
.. note:: |
|
This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. |
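
        Examples:
            A minimal sketch, given a constructed ``agent`` (the episode count is illustrative):

            >>> result = agent.batch_evaluate(n_evaluator_episode=8)
            >>> print(result.eval_value, result.eval_value_std)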
|
""" |
|
evaluator_env = create_env_manager( |
|
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] |
|
) |
|
|
|
evaluator_env.seed(self.cfg.seed, dynamic_seed=False) |
|
set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) |
|
|
|
|
|
|
|
|
|
policy_config = self.cfg.policy |
|
|
|
evaluator = Evaluator( |
|
eval_freq=self.cfg.policy.eval_freq, |
|
n_evaluator_episode=self.cfg.env.n_evaluator_episode |
|
if n_evaluator_episode is None else n_evaluator_episode, |
|
stop_value=self.cfg.env.stop_value, |
|
env=evaluator_env, |
|
policy=self.policy.eval_mode, |
|
exp_name=self.cfg.exp_name, |
|
policy_config=policy_config |
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # Run the evaluation over the requested number of episodes.
        stop, reward = evaluator.eval()
|
|
|
return EvalReturn( |
|
eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) |
|
) |
|
|
|
@property |
|
def best(self): |
|
""" |
|
Overview: |
|
Provides access to the best model according to evaluation metrics. |
|
Returns: |
|
- The agent with the best model loaded. |
|
|
|
.. note:: |
|
The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. |
|
When this property is accessed, the agent instance will load the best model state. |
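
        Examples:
            A minimal sketch (assumes training has already produced ``ckpt_best.pth.tar``):

            >>> best_agent = agent.best
            >>> result = best_agent.deploy()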
|
""" |
|
|
|
best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") |
|
|
|
        # Load the best checkpoint if it exists; otherwise keep the current policy weights.
        if os.path.exists(best_model_file_path):
|
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) |
|
self.policy.learn_mode.load_state_dict(policy_state_dict) |
|
return self |
|
|