from typing import List, Dict, Any, Optional, Tuple, Union

import numpy as np
import torch
from ding.policy.base_policy import Policy
from ding.utils import POLICY_REGISTRY

from lzero.policy import InverseScalarTransform, select_action, ez_network_output_unpack, mz_network_output_unpack


@POLICY_REGISTRY.register('lightzero_random_policy')
class LightZeroRandomPolicy(Policy):
    """
    Overview:
        The policy class for LightZero RandomPolicy.
    """

    def __init__(
        self,
        cfg: dict,
        model: Optional[Union[type, torch.nn.Module]] = None,
        enable_field: Optional[List[str]] = None,
        action_space: Any = None,
    ):
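        # Lazily import the MCTS tree implementations (C++ ctree and Python ptree) that match
        # the configured algorithm type; only the three pipelines below are supported here.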
        if cfg.type == 'muzero':
            from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
            from lzero.mcts import MuZeroMCTSPtree as MCTSPtree
        elif cfg.type == 'efficientzero':
            from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree
            from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree
        elif cfg.type == 'sampled_efficientzero':
            from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree
            from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree
        else:
            raise NotImplementedError("need to implement pipeline: {}".format(cfg.type))
        self.MCTSCtree = MCTSCtree
        self.MCTSPtree = MCTSPtree
        self.action_space = action_space
        super().__init__(cfg, model, enable_field)

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default model setting.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names.
                - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry.
                - import_names (:obj:`List[str]`): The model class path list used in this algorithm.
        .. note::
            The user can define and use a customized network model, but it must obey the same interface definition \
            indicated by the ``import_names`` path. For EfficientZero, it is ``lzero.model.efficientzero_model.EfficientZeroModel``.
        """
        if self._cfg.model.model_type == "conv":
            if self._cfg.type == 'efficientzero':
                return 'EfficientZeroModel', ['lzero.model.efficientzero_model']
            elif self._cfg.type == 'muzero':
                return 'MuZeroModel', ['lzero.model.muzero_model']
            elif self._cfg.type == 'sampled_efficientzero':
                return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model']
            else:
                raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))
        elif self._cfg.model.model_type == "mlp":
            if self._cfg.type == 'efficientzero':
                return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp']
            elif self._cfg.type == 'muzero':
                return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp']
            elif self._cfg.type == 'sampled_efficientzero':
                return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp']
            else:
                raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))
        else:
            raise NotImplementedError("need to implement model_type: {}".format(self._cfg.model.model_type))

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
        """
        self._collect_model = self._model
        if self._cfg.mcts_ctree:
            self._mcts_collect = self.MCTSCtree(self._cfg)
        else:
            self._mcts_collect = self.MCTSPtree(self._cfg)
        self._collect_mcts_temperature = 1
        self.collect_epsilon = 0.0
        self.inverse_scalar_transform_handle = InverseScalarTransform(
            self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
        )

    def _forward_collect(
        self,
        data: torch.Tensor,
        action_mask: list = None,
        temperature: float = 1,
        to_play: List = [-1],
        epsilon: float = 0.25,
        ready_env_id: np.ndarray = None,
    ) -> Dict:
        """
        Overview:
            The forward function for collecting data in collect mode. Use model to execute MCTS search.
            Choosing the action through sampling during the collect mode.
        Arguments:
            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
            - action_mask (:obj:`list`): The action mask, i.e. the mask indicating which actions are legal.
            - temperature (:obj:`float`): The temperature of the policy.
            - to_play (:obj:`List`): The player to play.
            - epsilon (:obj:`float`): The epsilon for epsilon-greedy exploration; not used by this random policy.
            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
        Shape:
            - data (:obj:`torch.Tensor`):
                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
            - temperature: :math:`(1, )`.
            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
            - ready_env_id: None
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``visit_count_distributions``, \
                ``visit_count_distribution_entropy``, ``searched_value``, ``predicted_value``, ``predicted_policy_logits``.
        """
        self._collect_model.eval()
        self._collect_mcts_temperature = temperature
        active_collect_env_num = data.shape[0]
        with torch.no_grad():
            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
            network_output = self._collect_model.initial_inference(data)
            if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
                latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
                    network_output
                )
            elif self._cfg.type == 'muzero':
                latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
            else:
                raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))

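            # Transform the predicted value back to its scalar form (inverse of the value transform)
            # and move the root statistics to CPU/numpy so they can be consumed by the MCTS tree.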
            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
            latent_state_roots = latent_state_roots.detach().cpu().numpy()
            if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
            policy_logits = policy_logits.detach().cpu().numpy().tolist()

            if self._cfg.model.continuous_action_space:
                # when the action space of the environment is continuous, action_mask[:] is None.
                # NOTE: in continuous action space env: we set all legal_actions as -1
                legal_actions = [
                    [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num)
                ]
            else:
                legal_actions = [
                    [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)
                ]

            # The only difference between collect and eval is the Dirichlet noise added to the root priors.
            if self._cfg.type in ['sampled_efficientzero']:
                noises = [
                    np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions)
                                        ).astype(np.float32).tolist() for j in range(active_collect_env_num)
                ]
            else:
                noises = [
                    np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
                                        ).astype(np.float32).tolist() for j in range(active_collect_env_num)
                ]

            if self._cfg.mcts_ctree:
                # cpp mcts_tree
                if self._cfg.type in ['sampled_efficientzero']:
                    roots = self.MCTSCtree.roots(
                        active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
                        self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                    )
                else:
                    roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions)
            else:
                # python mcts_tree
                if self._cfg.type in ['sampled_efficientzero']:
                    roots = self.MCTSPtree.roots(
                        active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
                        self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                    )
                else:
                    roots = self.MCTSPtree.roots(active_collect_env_num, legal_actions)

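            # Prepare the roots with exploration noise and the network priors, then run the MCTS search
            # so that the visit-count statistics attached to the output come from a real search.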
            if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
                roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play)
                self._mcts_collect.search(
                    roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play
                )
            elif self._cfg.type == 'muzero':
                roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
                self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play)
            else:
                raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))

            roots_visit_count_distributions = roots.get_distributions()
            roots_values = roots.get_values()  # shape: {list: batch_size}
            if self._cfg.type in ['sampled_efficientzero']:
                roots_sampled_actions = roots.get_sampled_actions()

            data_id = [i for i in range(active_collect_env_num)]
            output = {i: None for i in data_id}
            if ready_env_id is None:
                ready_env_id = np.arange(active_collect_env_num)

            for i, env_id in enumerate(ready_env_id):
                distributions, value = roots_visit_count_distributions[i], roots_values[i]

                if self._cfg.type in ['sampled_efficientzero']:
                    if self._cfg.mcts_ctree:
                        # In ctree, the method roots.get_sampled_actions() returns a list object.
                        root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])
                    else:
                        # In ptree, the same method roots.get_sampled_actions() returns an Action object.
                        root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])

                # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
                # the index within the legal action set, rather than the index in the entire action set.
                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
                    distributions, temperature=self._collect_mcts_temperature, deterministic=False
                )

                # ****************************************************************
                # NOTE: The action is randomly selected from the legal action set, 
                # the distribution is the real visit count distribution from the MCTS search.
                if self._cfg.type in ['sampled_efficientzero']:
                    # ****** sample a random action from the legal action set ********
                    random_action = self.action_space.sample()
                    output[env_id] = {
                        'action': random_action,
                        'visit_count_distributions': distributions,
                        'root_sampled_actions': root_sampled_actions,
                        'visit_count_distribution_entropy': visit_count_distribution_entropy,
                        'searched_value': value,
                        'predicted_value': pred_values[i],
                        'predicted_policy_logits': policy_logits[i],
                    }
                else:
                    # ****** sample a random action from the legal action set ********
                    random_action = int(np.random.choice(legal_actions[env_id], 1))
                    # all items except action are formally obtained from MCTS
                    output[env_id] = {
                        'action': random_action,
                        'visit_count_distributions': distributions,
                        'visit_count_distribution_entropy': visit_count_distribution_entropy,
                        'searched_value': value,
                        'predicted_value': pred_values[i],
                        'predicted_policy_logits': policy_logits[i],
                    }

        return output

    def _init_eval(self) -> None:
        """
        Overview:
            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
        """
        self._eval_model = self._model
        if self._cfg.mcts_ctree:
            self._mcts_eval = self.MCTSCtree(self._cfg)
        else:
            self._mcts_eval = self.MCTSPtree(self._cfg)

    # The following methods are no-op stubs, kept only for compatibility with the DI-engine ``Policy`` interface.
    def _init_learn(self) -> None:
        pass

    def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
        pass

    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.ndarray = None):
        pass

    def _monitor_vars_learn(self) -> List[str]:
        pass

    def _state_dict_learn(self) -> Dict[str, Any]:
        pass

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        pass

    def _process_transition(self, obs, policy_output, timestep):
        pass

    def _get_train_sample(self, data):
        pass
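

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original implementation): the block below
# only illustrates the discrete action-sampling rule used in ``_forward_collect``,
# i.e. drawing a uniformly random action from the legal action set encoded by a
# binary ``action_mask``. The mask values here are hypothetical; the snippet runs
# with numpy alone and does not exercise the policy, model, or MCTS machinery above.
if __name__ == "__main__":
    # Hypothetical mask for a 4-action environment: actions 0, 2 and 3 are legal.
    example_action_mask = [1, 0, 1, 1]
    example_legal_actions = [i for i, x in enumerate(example_action_mask) if x == 1]
    # Same sampling rule as in ``_forward_collect`` for the non-sampled (discrete) case.
    example_random_action = int(np.random.choice(example_legal_actions, 1))
    assert example_random_action in example_legal_actions
    print(f"sampled random legal action: {example_random_action}")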