import os
from functools import partial
from typing import Optional, Union, List

import numpy as np
import torch
from ding.bonus.common import TrainingReturn, EvalReturn
from ding.config import save_config_py, compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from ditk import logging
from easydict import EasyDict
from tensorboardX import SummaryWriter

from lzero.agent.config.muzero import supported_env_cfg
from lzero.entry.utils import log_buffer_memory_usage, random_collect
from lzero.mcts import MuZeroGameBuffer
from lzero.policy import visit_count_temperature
from lzero.policy.muzero import MuZeroPolicy
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator


class MuZeroAgent:
    """
    Overview:
        Agent class for the MuZero algorithm, providing methods for training, deployment, and batch evaluation.
    Interfaces:
        __init__, train, deploy, batch_evaluate
    Properties:
        best

    .. note::
        This agent class is tailored for use with the HuggingFace Model Zoo for LightZero
        (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-MuZero),
        and provides methods such as "train" and "deploy".
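
    Examples:
        An illustrative sketch; it assumes ``CartPole-v0`` is listed in ``supported_env_list``:

        >>> # The exp_name below is an arbitrary, illustrative choice.
        >>> agent = MuZeroAgent(env_id='CartPole-v0', exp_name='CartPole-v0-MuZero')
        >>> agent.train(step=int(1e5))
        >>> agent.deploy(enable_save_replay=True)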
    """

    supported_env_list = list(supported_env_cfg.keys())

    def __init__(
            self,
            env_id: Optional[str] = None,
            seed: int = 0,
            exp_name: Optional[str] = None,
            model: Optional[torch.nn.Module] = None,
            cfg: Optional[Union[EasyDict, dict]] = None,
            policy_state_dict: Optional[str] = None,
    ) -> None:
        """
        Overview:
            Initialize the MuZeroAgent instance with environment parameters, model, and configuration.
        Arguments:
            - env_id (:obj:`Optional[str]`): Identifier for the environment to be used, registered in gym.
            - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0.
            - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None.
            - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None.
            - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None.
            - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None.

        .. note::
            - If `env_id` is not specified, it must be included in `cfg`.
            - The `supported_env_list` contains all the environment IDs that are supported by this agent.
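
        Examples:
            An illustrative sketch of resuming from a pre-trained checkpoint; the path below is a hypothetical placeholder:

            >>> agent = MuZeroAgent(
            ...     env_id='CartPole-v0',
            ...     exp_name='CartPole-v0-MuZero',
            ...     policy_state_dict='path/to/ckpt_best.pth.tar',  # hypothetical checkpoint path
            ... )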
        """
        assert env_id is not None or cfg is not None, "Please specify env_id or cfg."

        if cfg is not None and not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)

        if env_id is not None:
            assert env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
                MuZeroAgent.supported_env_list
            )
            if cfg is None:
                cfg = supported_env_cfg[env_id]
            else:
                assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
        else:
            assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg."
            assert cfg.main_config.env.env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
                MuZeroAgent.supported_env_list
            )
        default_policy_config = EasyDict({"policy": MuZeroPolicy.default_config()})
        default_policy_config.policy.update(cfg.main_config.policy)
        cfg.main_config.policy = default_policy_config.policy

        if exp_name is not None:
            cfg.main_config.exp_name = exp_name
        self.origin_cfg = cfg
        self.cfg = compile_config(
            cfg.main_config, seed=seed, env=None, auto=True, policy=MuZeroPolicy, create_cfg=cfg.create_config
        )
        self.exp_name = self.cfg.exp_name

        logging.getLogger().setLevel(logging.INFO)
        self.seed = seed
        set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
        if not os.path.exists(self.exp_name):
            os.makedirs(self.exp_name)
        save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py'))
        if model is None:
            if self.cfg.policy.model.model_type == 'mlp':
                from lzero.model.muzero_model_mlp import MuZeroModelMLP
                model = MuZeroModelMLP(**self.cfg.policy.model)
            elif self.cfg.policy.model.model_type == 'conv':
                from lzero.model.muzero_model import MuZeroModel
                model = MuZeroModel(**self.cfg.policy.model)
            else:
                raise NotImplementedError
        if self.cfg.policy.cuda and torch.cuda.is_available():
            self.cfg.policy.device = 'cuda'
        else:
            self.cfg.policy.device = 'cpu'
        self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
        if policy_state_dict is not None:
            if isinstance(policy_state_dict, str):
                policy_state_dict = torch.load(policy_state_dict, map_location=torch.device("cpu"))
            self.policy.learn_mode.load_state_dict(policy_state_dict)
        self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")

        self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)

    def train(
        self,
        step: int = int(1e7),
    ) -> TrainingReturn:
        """
        Overview:
            Train the agent through interactions with the environment.
        Arguments:
            - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7).
        Returns:
            - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard.
        .. note::
            The method involves interacting with the environment, collecting experience, and optimizing the model.
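
        Examples:
            An illustrative sketch; the environment id and step budget are arbitrary choices:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> agent.train(step=int(2e5))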
        """

        collector_env = create_env_manager(
            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg]
        )
        evaluator_env = create_env_manager(
            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
        )

        collector_env.seed(self.cfg.seed)
        evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
        set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)

        # Create worker components: learner, collector, evaluator, replay buffer, commander.
        tb_logger = SummaryWriter(
            os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial')
        ) if get_rank() == 0 else None
        learner = BaseLearner(
            self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name
        )

        # ==============================================================
        # MCTS+RL algorithms related core code
        # ==============================================================
        policy_config = self.cfg.policy
        batch_size = policy_config.batch_size
        # specific game buffer for MCTS+RL algorithms
        replay_buffer = MuZeroGameBuffer(policy_config)
        collector = Collector(
            env=collector_env,
            policy=self.policy.collect_mode,
            tb_logger=tb_logger,
            exp_name=self.cfg.exp_name,
            policy_config=policy_config
        )
        evaluator = Evaluator(
            eval_freq=self.cfg.policy.eval_freq,
            n_evaluator_episode=self.cfg.env.n_evaluator_episode,
            stop_value=self.cfg.env.stop_value,
            env=evaluator_env,
            policy=self.policy.eval_mode,
            tb_logger=tb_logger,
            exp_name=self.cfg.exp_name,
            policy_config=policy_config
        )

        # ==============================================================
        # Main loop
        # ==============================================================
        # Learner's before_run hook.
        learner.call_hook('before_run')

        if self.cfg.policy.update_per_collect is not None:
            update_per_collect = self.cfg.policy.update_per_collect

        # The purpose of collecting random data before training:
        # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
        # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
        if self.cfg.policy.random_collect_episode_num > 0:
            random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

        while True:
            log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
            collect_kwargs = {}
            # Set the temperature for the visit count distributions according to train_iter;
            # see Appendix D of the MuZero paper for details.
            collect_kwargs['temperature'] = visit_count_temperature(
                policy_config.manual_temperature_decay,
                policy_config.fixed_temperature_value,
                policy_config.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            )

            if policy_config.eps.eps_greedy_exploration_in_collect:
                epsilon_greedy_fn = get_epsilon_greedy_fn(
                    start=policy_config.eps.start,
                    end=policy_config.eps.end,
                    decay=policy_config.eps.decay,
                    type_=policy_config.eps.type
                )
                collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
            else:
                collect_kwargs['epsilon'] = 0.0

            # Evaluate policy performance.
            if evaluator.should_eval(learner.train_iter):
                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break

            # Collect data by default config n_sample/n_episode.
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            if self.cfg.policy.update_per_collect is None:
                # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
                collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
                update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
            # Push the new data collected by the collector into the replay buffer.
            replay_buffer.push_game_segments(new_data)
            # remove the oldest data if the replay buffer is full.
            replay_buffer.remove_oldest_data_to_fit()

            # Learn policy from collected data.
            for i in range(update_per_collect):
                # Learner will train ``update_per_collect`` times in one iteration.
                if replay_buffer.get_num_of_transitions() > batch_size:
                    train_data = replay_buffer.sample(batch_size, self.policy)
                else:
                    logging.warning(
                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
                        f'batch_size: {batch_size}, '
                        f'replay_buffer: {replay_buffer}. '
                        f'Continue to collect now ....'
                    )
                    break

                # The core train steps for MCTS+RL algorithms.
                log_vars = learner.train(train_data, collector.envstep)

                if self.cfg.policy.use_priority:
                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

            if collector.envstep >= step:
                break

        # Learner's after_run hook.
        learner.call_hook('after_run')

        return TrainingReturn(wandb_url=None)

    def deploy(
            self,
            enable_save_replay: bool = False,
            concatenate_all_replay: bool = False,
            replay_save_path: str = None,
            seed: Optional[Union[int, List]] = None,
            debug: bool = False
    ) -> EvalReturn:
        """
        Overview:
            Deploy the agent for evaluation in the environment, with optional replay saving. The mean and standard
            deviation of the episode returns are reported.
            If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`.
        Arguments:
            - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False.
            - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False.
            - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path.
            - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None.
            - debug (:obj:`bool`): Whether to enable debug mode. Defaults to False.
        Returns:
            - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns.
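
        Examples:
            An illustrative sketch; the seeds and save path are arbitrary choices:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> result = agent.deploy(
            ...     enable_save_replay=True,
            ...     concatenate_all_replay=True,
            ...     replay_save_path='./videos',  # hypothetical output directory
            ...     seed=[0, 1, 2],
            ... )
            >>> print(result.eval_value, result.eval_value_std)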
        """

        deploy_configs = [self.evaluator_env_cfg[0]]

        if isinstance(seed, int):
            seed_list = [seed]
        elif seed:
            seed_list = seed
        else:
            seed_list = [0]

        reward_list = []

        if enable_save_replay:
            replay_save_path = replay_save_path if replay_save_path is not None else os.path.join(
                self.exp_name, 'videos'
            )
            deploy_configs[0]['replay_path'] = replay_save_path

        for seed in seed_list:

            evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deploy_configs[0])])

            evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False)
            set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda)

            # ==============================================================
            # MCTS+RL algorithms related core code
            # ==============================================================
            policy_config = self.cfg.policy

            evaluator = Evaluator(
                eval_freq=self.cfg.policy.eval_freq,
                n_evaluator_episode=1,
                stop_value=self.cfg.env.stop_value,
                env=evaluator_env,
                policy=self.policy.eval_mode,
                exp_name=self.cfg.exp_name,
                policy_config=policy_config
            )

            # ==============================================================
            # Main loop
            # ==============================================================

            stop, reward = evaluator.eval()
            reward_list.extend(reward['eval_episode_return'])

        if enable_save_replay:
            files = os.listdir(replay_save_path)
            files = [file for file in files if file.endswith('0.mp4')]
            files.sort()
            if concatenate_all_replay:
                # create a file named 'files.txt' to store the names of all mp4 files
                with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f:
                    for file in files:
                        f.write("file '{}'\n".format(file))

                # Combine all the mp4 files into a single file named 'deploy.mp4'.
                os.system(
                    'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format(
                        os.path.join(replay_save_path, 'files.txt'), replay_save_path
                    )
                )

        return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))

    def batch_evaluate(
        self,
        n_evaluator_episode: Optional[int] = None,
    ) -> EvalReturn:
        """
        Overview:
            Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``.
        Arguments:
            - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation.
                If None, uses default value from configuration. Defaults to None.
        Returns:
            - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns.

        .. note::
            This method evaluates the agent's performance across multiple episodes to gauge its effectiveness.
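
        Examples:
            An illustrative sketch; the episode count is an arbitrary choice:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> result = agent.batch_evaluate(n_evaluator_episode=8)
            >>> print(result.eval_value, result.eval_value_std)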
        """
        evaluator_env = create_env_manager(
            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
        )

        evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
        set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)

        # ==============================================================
        # MCTS+RL algorithms related core code
        # ==============================================================
        policy_config = self.cfg.policy

        evaluator = Evaluator(
            eval_freq=self.cfg.policy.eval_freq,
            n_evaluator_episode=self.cfg.env.n_evaluator_episode
            if n_evaluator_episode is None else n_evaluator_episode,
            stop_value=self.cfg.env.stop_value,
            env=evaluator_env,
            policy=self.policy.eval_mode,
            exp_name=self.cfg.exp_name,
            policy_config=policy_config
        )

        # ==============================================================
        # Main loop
        # ==============================================================

        stop, reward = evaluator.eval()

        return EvalReturn(
            eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return'])
        )

    @property
    def best(self):
        """
        Overview:
            Provides access to the best model according to evaluation metrics.
        Returns:
            - The agent with the best model loaded.

        .. note::
            The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`.
            When this property is accessed, the agent instance will load the best model state.
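
        Examples:
            An illustrative sketch; it assumes training has already produced ``ckpt_best.pth.tar`` in the checkpoint directory:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> agent.train(step=int(1e5))
            >>> agent.best.deploy(enable_save_replay=True)  # deploy with the best checkpoint loaded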
        """

        best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar")
        # Load best model if it exists
        if os.path.exists(best_model_file_path):
            policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
            self.policy.learn_mode.load_state_dict(policy_state_dict)
        return self