import copy
from typing import List, Tuple

import numpy as np
from easydict import EasyDict

from ding.utils.compression_helper import jpeg_data_decompressor


class GameSegment:
    """
    Overview:
        A game segment from a full episode trajectory.

        The length of one episode in (Atari) games is often quite large. This class represents a single game segment
        within a larger trajectory, split into several blocks.

    Interfaces:
        - __init__
        - __len__
        - reset
        - pad_over
        - is_full
        - legal_actions
        - append
        - get_obs
        - get_unroll_obs
        - zero_obs
        - get_targets
        - game_segment_to_array
        - store_search_stats
    """

    def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
        """
        Overview:
            Init the ``GameSegment`` according to the provided arguments.
        Arguments:
            - action_space (:obj:`int`): the action space. ``legal_actions`` accepts either an integer size or an
              object exposing its size as ``.n``.
            - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block.
            - config (:obj:`EasyDict`): the configuration. Fields read here: ``num_unroll_steps``, ``td_steps``,
              ``discount_factor``, ``gray_scale``, ``transform2string``, ``sampled_algo``, ``gumbel_algo``,
              ``use_ture_chance_label_in_chance_encoder``, and ``model.frame_stack_num``,
              ``model.action_space_size``, ``model.observation_shape``, ``model.image_channel``.
        """
        self.action_space = action_space
        self.game_segment_length = game_segment_length
        self.num_unroll_steps = config.num_unroll_steps
        self.td_steps = config.td_steps
        self.frame_stack_num = config.model.frame_stack_num
        self.discount_factor = config.discount_factor
        self.action_space_size = config.model.action_space_size
        self.gray_scale = config.gray_scale
        self.transform2string = config.transform2string
        self.sampled_algo = config.sampled_algo
        self.gumbel_algo = config.gumbel_algo
        self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

        if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
            # for vector obs input, e.g. classical control and box2d environments
            self.zero_obs_shape = config.model.observation_shape
        elif len(config.model.observation_shape) == 3:
            # image obs input, e.g. atari environments
            self.zero_obs_shape = (
                config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
            )
        # NOTE(review): any other observation rank leaves ``zero_obs_shape`` unset, so ``zero_obs`` would raise
        # AttributeError for it.

        self._init_buffers()

    def _init_buffers(self) -> None:
        """
        Overview:
            Create every per-segment buffer as an empty list. Shared by ``__init__`` and ``reset`` so that a reset
            segment never keeps stale (or already array-converted) data from a previous use.
        """
        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []

        self.target_values = []
        self.target_rewards = []
        self.target_policies = []

        self.improved_policy_probs = []

        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

    def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
        """
        Overview:
            Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps].
        Arguments:
            - timestep (int): The time step.
            - num_unroll_steps (int): The extra length of the observation frames.
            - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory.
        Returns:
            - stacked_obs: the sliced frames; padded by repeating the last available frame when ``padding`` is set,
              and decompressed with ``jpeg_data_decompressor`` when ``transform2string`` is enabled.
        """
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]
        if padding:
            pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs)
            if pad_len > 0:
                # Repeat the last available frame to reach the requested length.
                pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
                stacked_obs = np.concatenate((stacked_obs, pad_frames))
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def zero_obs(self) -> List:
        """
        Overview:
            Return a stack of ``frame_stack_num`` observation frames filled with zeros.
        Returns:
            List[ndarray]: ``frame_stack_num`` float32 arrays of shape ``zero_obs_shape``.
        """
        return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

    def get_obs(self) -> List:
        """
        Overview:
            Return the most recent ``frame_stack_num`` observations in the correct format for model inference.
        Returns:
            stacked_obs (List): An observation in the correct format for model inference.
        """
        # obs_segment holds frame_stack_num initial frames plus one frame per appended transition.
        timestep_obs = len(self.obs_segment) - self.frame_stack_num
        timestep_reward = len(self.reward_segment)
        assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
            timestep_obs, timestep_reward
        )
        timestep = timestep_reward
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def append(
            self,
            action: np.ndarray,
            obs: np.ndarray,
            reward: np.ndarray,
            action_mask: np.ndarray = None,
            to_play: int = -1,
            chance: int = 0,
    ) -> None:
        """
        Overview:
            Append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}.
            ``chance`` is recorded only when ``use_ture_chance_label_in_chance_encoder`` is enabled.
        """
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)

        self.action_mask_segment.append(action_mask)
        self.to_play_segment.append(to_play)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.append(chance)

    def pad_over(
            self,
            next_segment_observations: List,
            next_segment_rewards: List,
            next_segment_root_values: List,
            next_segment_child_visits: List,
            next_segment_improved_policy: List = None,
            next_chances: List = None,
    ) -> None:
        """
        Overview:
            To make sure the correction of value targets, we need to add (o_t, r_t, etc) from the next game_segment,
            which is necessary for the bootstrapped values at the end states of the previous game_segment.
            e.g: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... are from the next game_segment.
        Arguments:
            - next_segment_observations (:obj:`list`): o_t from the next game_segment
            - next_segment_rewards (:obj:`list`): r_t from the next game_segment
            - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment
            - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
            - next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next
              game_segment (Only used in Gumbel MuZero)
            - next_chances (:obj:`list`): chance labels from the next game_segment (required when
              ``use_ture_chance_label_in_chance_encoder`` is enabled)
        """
        assert len(next_segment_observations) <= self.num_unroll_steps
        assert len(next_segment_child_visits) <= self.num_unroll_steps
        assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
        assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        if self.gumbel_algo:
            assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps
        # Validate before mutating anything, so a missing argument cannot leave the segment half-padded.
        if self.use_ture_chance_label_in_chance_encoder:
            assert next_chances is not None, \
                "next_chances is required when use_ture_chance_label_in_chance_encoder is True"

        # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
        for observation in next_segment_observations:
            self.obs_segment.append(copy.deepcopy(observation))

        for reward in next_segment_rewards:
            self.reward_segment.append(reward)

        for value in next_segment_root_values:
            self.root_value_segment.append(value)

        for child_visits in next_segment_child_visits:
            self.child_visit_segment.append(child_visits)

        if self.gumbel_algo:
            for improved_policy in next_segment_improved_policy:
                self.improved_policy_probs.append(improved_policy)
        if self.use_ture_chance_label_in_chance_encoder:
            for chances in next_chances:
                self.chance_segment.append(chances)

    def get_targets(self, timestep: int) -> Tuple:
        """
        Overview:
            Return the (value, reward, policy) targets at step ``timestep``.
        """
        return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]

    def store_search_stats(
            self,
            visit_counts: List,
            root_value: List,
            root_sampled_actions: List = None,
            improved_policy: List = None,
            idx: int = None
    ) -> None:
        """
        Overview:
            Store the normalized visit count distribution and the value of the root node after MCTS.
            If ``idx`` is None the stats are appended; otherwise the entries at position ``idx`` are overwritten.
        Arguments:
            - visit_counts (:obj:`list`): root children visit counts; normalized by their sum before storing.
            - root_value: the root value of MCTS.
            - root_sampled_actions (:obj:`list`): sampled root actions (stored only in the append path when
              ``sampled_algo`` is enabled).
            - improved_policy (:obj:`list`): improved policy (stored only when ``gumbel_algo`` is enabled).
            - idx (:obj:`int`): optional position to overwrite instead of appending.
        """
        sum_visits = sum(visit_counts)
        visit_distribution = [visit_count / sum_visits for visit_count in visit_counts]
        if idx is None:
            self.child_visit_segment.append(visit_distribution)
            self.root_value_segment.append(root_value)
            if self.sampled_algo:
                self.root_sampled_actions.append(root_sampled_actions)
            # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
            if self.gumbel_algo:
                self.improved_policy_probs.append(improved_policy)
        else:
            self.child_visit_segment[idx] = visit_distribution
            self.root_value_segment[idx] = root_value
            # Only Gumbel MuZero fills improved_policy_probs; for other algorithms it stays empty and an
            # unconditional indexed write would raise IndexError.
            if self.gumbel_algo:
                self.improved_policy_probs[idx] = improved_policy

    def game_segment_to_array(self) -> None:
        """
        Overview:
            Post-process the data when a `GameSegment` block is full. This function converts various game segment
            elements into numpy arrays for easier manipulation and processing.

        Structure:
            The structure and shapes of different game segment elements are as follows. Let's assume
            `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:

            - obs:            game_segment_length + stack + num_unroll_steps -> 20+4+5
            - action:         game_segment_length -> 20
            - reward:         game_segment_length + num_unroll_steps + td_steps - 1 -> 20+5+5-1
            - root_values:    game_segment_length + num_unroll_steps + td_steps -> 20+5+5
            - child_visits:   game_segment_length + num_unroll_steps -> 20+5
            - to_play:        game_segment_length -> 20
            - action_mask:    game_segment_length -> 20

        Examples:
            Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
            (game_segment_i and game_segment_i+1):

            - game_segment_i (obs):       4       20        5
                                       ----|----...----|-----|
            - game_segment_i+1 (obs):                 4       20        5
                                                   ----|----...----|-----|

            - game_segment_i (rew):          20        5      4
                                       ----...----|------|-----|
            - game_segment_i+1 (rew):                   20        5      4
                                                  ----...----|------|-----|

        Postprocessing:
            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment.
              Only created if `self.use_ture_chance_label_in_chance_encoder` is True.

        .. note::
            For environments with a variable action space, such as board games, the elements in `child_visit_segment`
            may have different lengths. In such scenarios, it is necessary to use the object data type for
            `self.child_visit_segment`.
        """
        self.obs_segment = np.array(self.obs_segment)
        self.action_segment = np.array(self.action_segment)
        self.reward_segment = np.array(self.reward_segment)

        # Check if all elements in self.child_visit_segment have the same length
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # In the case of environments with a variable action space, such as board games,
            # the elements in child_visit_segment may have different lengths.
            # In such scenarios, it is necessary to use the object data type.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)

        self.root_value_segment = np.array(self.root_value_segment)
        self.improved_policy_probs = np.array(self.improved_policy_probs)

        self.action_mask_segment = np.array(self.action_mask_segment)
        self.to_play_segment = np.array(self.to_play_segment)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = np.array(self.chance_segment)

    def reset(self, init_observations: np.ndarray) -> None:
        """
        Overview:
            Initialize the game segment using ``init_observations``,
            which is the previous ``frame_stack_num`` stacked frames.
            All per-segment buffers (including improved policies, sampled actions and targets) are cleared first.
        Arguments:
            - init_observations (:obj:`list`): list of the stack observations in the previous time steps.
        """
        self._init_buffers()

        assert len(init_observations) == self.frame_stack_num

        for observation in init_observations:
            self.obs_segment.append(copy.deepcopy(observation))

    def is_full(self) -> bool:
        """
        Overview:
            Check whether the current game segment is full, i.e. larger than the segment length.
        Returns:
            bool: True if the game segment is full, False otherwise.
        """
        return len(self.action_segment) >= self.game_segment_length

    def legal_actions(self) -> List[int]:
        """
        Overview:
            Return every action index ``[0, 1, ..., n - 1]`` of the action space.
        """
        # ``action_space`` is documented as an ``int``, but the original implementation read ``.n``;
        # accept both forms.
        num_actions = getattr(self.action_space, 'n', self.action_space)
        return list(range(num_actions))

    def __len__(self) -> int:
        """Return the number of transitions (actions) stored in this segment."""
        return len(self.action_segment)