import copy
from typing import List, Tuple

import numpy as np
from easydict import EasyDict

from ding.utils.compression_helper import jpeg_data_decompressor


class GameSegment:
    """
    Overview:
        A game segment from a full episode trajectory.

        The length of one episode in (Atari) games is often quite large. This class represents a single game segment
        within a larger trajectory, split into several blocks.

    Interfaces:
        - __init__
        - __len__
        - reset
        - pad_over
        - is_full
        - legal_actions
        - append
        - get_observation
        - zero_obs
        - step_obs
        - get_targets
        - game_segment_to_array
        - store_search_stats
    """

    def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
        """
        Overview:
            Init the ``GameSegment`` according to the provided arguments.
        Arguments:
            - action_space (:obj:`int`): The action space of the environment.
            - game_segment_length (:obj:`int`): The number of transitions in one ``GameSegment`` block.
            - config (:obj:`EasyDict`): The algorithm config, including unroll/TD settings and the model config.
        """
        self.action_space = action_space
        self.game_segment_length = game_segment_length
        self.num_unroll_steps = config.num_unroll_steps
        self.td_steps = config.td_steps
        self.frame_stack_num = config.model.frame_stack_num
        self.discount_factor = config.discount_factor
        self.action_space_size = config.model.action_space_size
        self.gray_scale = config.gray_scale
        self.transform2string = config.transform2string
        self.sampled_algo = config.sampled_algo
        self.gumbel_algo = config.gumbel_algo
        self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

        if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
            # for vector obs input, e.g. classical control and box2d environments
            self.zero_obs_shape = config.model.observation_shape
        elif len(config.model.observation_shape) == 3:
            # image obs input, e.g. atari environments
            self.zero_obs_shape = (
                config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
            )

        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []

        self.target_values = []
        self.target_rewards = []
        self.target_policies = []

        self.improved_policy_probs = []

        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []


    def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
        """
        Overview:
            Get the stacked observations o[t: t + frame_stack_num + num_unroll_steps] starting at ``timestep``.
        Arguments:
            - timestep (:obj:`int`): The time step t.
            - num_unroll_steps (:obj:`int`): The extra number of observation frames to include after the stacked frames.
            - padding (:obj:`bool`): If True, pad with copies of the last frame when the requested window exceeds the trajectory length.
        """
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]
        if padding:
            pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs)
            if pad_len > 0:
                pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
                stacked_obs = np.concatenate((stacked_obs, pad_frames))
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def zero_obs(self) -> List:
        """
        Overview:
            Return a list of ``frame_stack_num`` observation frames filled with zeros.
        Returns:
            - zero_obs (:obj:`list`): A list of ``frame_stack_num`` zero-filled arrays of shape ``zero_obs_shape``.
        """
        return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

    def get_obs(self) -> List:
        """
        Overview:
            Return the latest stacked observation in the correct format for model inference.
        Returns:
            - stacked_obs (:obj:`list`): The ``frame_stack_num`` most recent observation frames.
        """
        timestep_obs = len(self.obs_segment) - self.frame_stack_num
        timestep_reward = len(self.reward_segment)
        assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
            timestep_obs, timestep_reward
        )
        timestep = timestep_reward
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def append(
            self,
            action: np.ndarray,
            obs: np.ndarray,
            reward: np.ndarray,
            action_mask: np.ndarray = None,
            to_play: int = -1,
            chance: int = 0,
    ) -> None:
        """
        Overview:
            Append a transition tuple, including a_t, o_{t+1}, r_t, action_mask_t, to_play_t and, optionally, chance_t.
        """
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)

        self.action_mask_segment.append(action_mask)
        self.to_play_segment.append(to_play)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.append(chance)

    def pad_over(
            self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
            next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
    ) -> None:
        """
        Overview:
            To ensure the correctness of the value targets, we need to append (o_t, r_t, etc.) from the next game_segment,
            which are required for bootstrapping the values at the end states of the previous game_segment.
            e.g.: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... come from the next game_segment.
        Arguments:
            - next_segment_observations (:obj:`list`): o_t from the next game_segment.
            - next_segment_rewards (:obj:`list`): r_t from the next game_segment.
            - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment.
            - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment.
            - next_segment_improved_policy (:obj:`list`): root children selection policy of MCTS from the next game_segment (only used in Gumbel MuZero).
            - next_chances (:obj:`list`): ground-truth chance labels from the next game_segment (only used when ``use_ture_chance_label_in_chance_encoder`` is True).
        """
        assert len(next_segment_observations) <= self.num_unroll_steps
        assert len(next_segment_child_visits) <= self.num_unroll_steps
        assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
        assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        if self.gumbel_algo:
            assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

        # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
        for observation in next_segment_observations:
            self.obs_segment.append(copy.deepcopy(observation))

        for reward in next_segment_rewards:
            self.reward_segment.append(reward)

        for value in next_segment_root_values:
            self.root_value_segment.append(value)

        for child_visits in next_segment_child_visits:
            self.child_visit_segment.append(child_visits)
        
        if self.gumbel_algo:
            for improved_policy in next_segment_improved_policy:
                self.improved_policy_probs.append(improved_policy)
        if self.use_ture_chance_label_in_chance_encoder:
            for chances in next_chances:
                self.chance_segment.append(chances)

    def get_targets(self, timestep: int) -> Tuple:
        """
        Overview:
            Return the value, reward, and policy targets at step ``timestep``.
        """
        return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]

    def store_search_stats(
            self, visit_counts: List, root_value: List, root_sampled_actions: List = None, improved_policy: List = None, idx: int = None
    ) -> None:
        """
        Overview:
            Store the normalized visit count distribution and the value of the root node after MCTS.
            If ``idx`` is provided, overwrite the statistics at that position instead of appending.
        """
        sum_visits = sum(visit_counts)
        if idx is None:
            self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts])
            self.root_value_segment.append(root_value)
            if self.sampled_algo:
                self.root_sampled_actions.append(root_sampled_actions)
            # store the improved policy in Gumbel MuZero: \pi'=softmax(logits + \sigma(CompletedQ))
            if self.gumbel_algo:
                self.improved_policy_probs.append(improved_policy)
        else:
            self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts]
            self.root_value_segment[idx] = root_value
            self.improved_policy_probs[idx] = improved_policy

    def game_segment_to_array(self) -> None:
        """
        Overview:
            Post-process the data when a `GameSegment` block is full. This function converts various game segment
            elements into numpy arrays for easier manipulation and processing.
        Structure:
            The structure and shapes of different game segment elements are as follows. Let's assume
            `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:

            - obs:            game_segment_length + stack + num_unroll_steps -> 20+4+5
            - action:         game_segment_length -> 20
            - reward:         game_segment_length + num_unroll_steps + td_steps - 1 -> 20+5+5-1
            - root_values:    game_segment_length + num_unroll_steps + td_steps -> 20+5+5
            - child_visits:   game_segment_length + num_unroll_steps -> 20+5
            - to_play:        game_segment_length -> 20
            - action_mask:    game_segment_length -> 20
        Examples:
            Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
            (game_segment_i and game_segment_i+1):

            - game_segment_i (obs):     4       20        5
                                      ----|----...----|-----|
            - game_segment_i+1 (obs):              4       20        5
                                                  ----|----...----|-----|

            - game_segment_i (rew):        20        5      4
                                      ----...----|------|-----|
            - game_segment_i+1 (rew):                 20        5    4
                                                 ----...----|------|-----|

        Postprocessing:
            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment.
              Only created if `self.use_ture_chance_label_in_chance_encoder` is True.

        .. note::
            For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have
            different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`.
        """
        self.obs_segment = np.array(self.obs_segment)
        self.action_segment = np.array(self.action_segment)
        self.reward_segment = np.array(self.reward_segment)

        # Check if all elements in self.child_visit_segment have the same length
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # In the case of environments with a variable action space, such as board games,
            # the elements in child_visit_segment may have different lengths.
            # In such scenarios, it is necessary to use the object data type.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)

        self.root_value_segment = np.array(self.root_value_segment)
        self.improved_policy_probs = np.array(self.improved_policy_probs)

        self.action_mask_segment = np.array(self.action_mask_segment)
        self.to_play_segment = np.array(self.to_play_segment)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = np.array(self.chance_segment)

    def reset(self, init_observations: np.ndarray) -> None:
        """
        Overview:
            Initialize the game segment using ``init_observations``,
            which are the ``frame_stack_num`` stacked frames from the previous time steps.
        Arguments:
            - init_observations (:obj:`list`): list of the stacked observations from the previous time steps.
        """
        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

        assert len(init_observations) == self.frame_stack_num

        for observation in init_observations:
            self.obs_segment.append(copy.deepcopy(observation))

    def is_full(self) -> bool:
        """
        Overview:
            Check whether the current game segment is full, i.e. whether it contains at least ``game_segment_length`` transitions.
        Returns:
            bool: True if the game segment is full, False otherwise.
        """
        return len(self.action_segment) >= self.game_segment_length

    def legal_actions(self) -> List[int]:
        """
        Overview:
            Return the full list of legal action indices, assuming a discrete action space with ``n`` actions.
        """
        return list(range(self.action_space.n))

    def __len__(self) -> int:
        """
        Overview:
            Return the number of transitions stored in this game segment.
        """
        return len(self.action_segment)
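

# ---------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The EasyDict below is a hypothetical config that merely
# mirrors the fields read in ``__init__``; real configs are supplied by the algorithm entry points.
# It shows the typical lifecycle of a segment: reset -> append / store_search_stats -> game_segment_to_array.
# ---------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = EasyDict(
        dict(
            num_unroll_steps=5,
            td_steps=5,
            discount_factor=0.997,
            gray_scale=False,
            transform2string=False,
            sampled_algo=False,
            gumbel_algo=False,
            use_ture_chance_label_in_chance_encoder=False,
            model=dict(frame_stack_num=4, action_space_size=2, observation_shape=4),
        )
    )
    # ``action_space`` is only used by ``legal_actions``, so it is omitted in this sketch.
    segment = GameSegment(action_space=None, game_segment_length=8, config=cfg)

    # Reset with ``frame_stack_num`` initial (vector) observations.
    segment.reset([np.zeros(4, dtype=np.float32) for _ in range(cfg.model.frame_stack_num)])

    for t in range(8):
        # Append one transition and the (here uniform) root statistics from MCTS.
        segment.append(action=0, obs=np.ones(4, dtype=np.float32), reward=1.0)
        segment.store_search_stats(visit_counts=[10, 10], root_value=0.5)

    print(len(segment), segment.is_full())  # -> 8 True
    segment.game_segment_to_array()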