from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional

import numpy as np
import torch
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree
from lzero.mcts.utils import prepare_observation
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer import GameBuffer

if TYPE_CHECKING:
    from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy


@BUFFER_REGISTRY.register('game_buffer_muzero')
class MuZeroGameBuffer(GameBuffer):
    """
    Overview:
        The specific game buffer for MuZero policy.
    """

    def __init__(self, cfg: dict):
        """
        Overview:
            Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
            in the default configuration, the user-provided value will override the default configuration. Otherwise,
            the default configuration will be used.
        Arguments:
            - cfg (:obj:`dict`): user configuration, merged on top of ``self.default_config()``.
        """
        super().__init__(cfg)
        default_config = self.default_config()
        default_config.update(cfg)
        self._cfg = default_config
        assert self._cfg.env_type in ['not_board_games', 'board_games']
        assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
        self.replay_buffer_size = self._cfg.replay_buffer_size
        self.batch_size = self._cfg.batch_size
        self._alpha = self._cfg.priority_prob_alpha
        self._beta = self._cfg.priority_prob_beta

        self.keep_ratio = 1
        self.model_update_interval = 10
        self.num_of_collected_episodes = 0
        self.base_idx = 0
        # ``update_priority`` only writes priorities whose ``make_time`` is greater than ``clear_time``.
        self.clear_time = 0

        self.game_segment_buffer = []
        self.game_pos_priorities = []
        self.game_segment_game_pos_look_up = []

    def sample(
            self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
    ) -> List[Any]:
        """
        Overview:
            Sample data from ``GameBuffer`` and prepare the current and target batch for training.
        Arguments:
            - batch_size (:obj:`int`): batch size.
            - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy.
        Returns:
            - train_data (:obj:`List`): List of train data, including current_batch and target_batch.
        Raises:
            - ValueError: if ``self._cfg.reanalyze_ratio`` is not within ``[0, 1]``.
        """
        reanalyze_ratio = self._cfg.reanalyze_ratio
        # Fail fast: previously an out-of-range ratio only surfaced as an UnboundLocalError after all the
        # (expensive) target computation had already run.
        if not 0 <= reanalyze_ratio <= 1:
            raise ValueError(f"reanalyze_ratio must be within [0, 1], but got {reanalyze_ratio}")

        policy._target_model.to(self._cfg.device)
        policy._target_model.eval()

        # obtain the current_batch and prepare target context
        reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
            batch_size, reanalyze_ratio
        )
        # target reward, target value
        batch_rewards, batch_target_values = self._compute_target_reward_value(
            reward_value_context, policy._target_model
        )
        # target policy
        batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
        batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
            policy_non_re_context, self._cfg.model.action_space_size
        )

        # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies.
        # A context is None when its share of the batch is empty, and the matching ``_compute_*`` method then
        # returns an empty list. Concatenating that empty list with the other (multi-dimensional) array raises
        # in numpy, which happened whenever int(batch_size * reanalyze_ratio) == 0 with 0 < reanalyze_ratio < 1.
        if policy_re_context is None:
            batch_target_policies = batch_target_policies_non_re
        elif policy_non_re_context is None:
            batch_target_policies = batch_target_policies_re
        else:
            batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])

        target_batch = [batch_rewards, batch_target_values, batch_target_policies]

        # a batch contains the current_batch and the target_batch
        train_data = [current_batch, target_batch]
        return train_data

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            first sample orig_data through ``_sample_orig_data()``,
            then prepare the context of a batch:
                reward_value_context:        the context of reanalyzed value targets
                policy_re_context:           the context of reanalyzed policy targets
                policy_non_re_context:       the context of non-reanalyzed policy targets
                current_batch:               the inputs of batch
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
            - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch.
              ``policy_re_context`` / ``policy_non_re_context`` is None when its share of the batch is empty.
        """
        # obtain the batch context from replay buffer
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        # from here on, use the number of samples actually returned
        batch_size = len(batch_index_list)
        obs_list, action_list, mask_list = [], [], []
        # prepare the inputs of a batch
        for i in range(batch_size):
            game = game_segment_list[i]
            pos_in_game_segment = pos_in_game_segment_list[i]

            actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                              self._cfg.num_unroll_steps].tolist()
            # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
            mask_tmp = [1. for _ in range(len(actions_tmp))]
            mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]

            # pad random action
            actions_tmp += [
                np.random.randint(0, game.action_space_size)
                for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
            ]

            # obtain the input observations
            # pad if length of obs in game_segment is less than stack+num_unroll_steps
            # e.g. stack+num_unroll_steps = 4+5
            obs_list.append(
                game_segment_list[i].get_unroll_obs(
                    pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
                )
            )
            action_list.append(actions_tmp)
            mask_list.append(mask_tmp)

        # formalize the input observations
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # formalize the inputs of a batch
        current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
        current_batch = [np.asarray(item) for item in current_batch]

        total_transitions = self.get_num_of_transitions()

        # obtain the context of value targets
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )

        # only reanalyze recent reanalyze_ratio (e.g. 50%) data
        # if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
        # 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
        reanalyze_num = int(batch_size * reanalyze_ratio)

        # reanalyzed policy
        if reanalyze_num > 0:
            # obtain the context of reanalyzed policy targets
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # non reanalyzed policy
        if reanalyze_num < batch_size:
            # obtain the context of non-reanalyzed policy targets
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context

    def _prepare_reward_value_context(
            self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any],
            total_transitions: int
    ) -> List[Any]:
        """
        Overview:
            prepare the context of rewards and values for calculating TD value target in reanalyzing part.
        Arguments:
            - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
            - game_segment_list (:obj:`list`): list of game segments
            - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment
            - total_transitions (:obj:`int`): number of collected transitions (not used by this implementation)
        Returns:
            - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
              td_steps_list, action_mask_segment, to_play_segment
        """
        zero_obs = game_segment_list[0].zero_obs()
        value_obs_list = []
        # the value is valid or not (out of game_segment)
        value_mask = []
        rewards_list = []
        game_segment_lens = []
        # for board games
        action_mask_segment, to_play_segment = [], []

        td_steps_list = []
        for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
            game_segment_len = len(game_segment)
            game_segment_lens.append(game_segment_len)

            # clip the configured td_steps so it is at least 1 and at most the remaining length of the segment
            td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32)

            # prepare the corresponding observations for bootstrapped values o_{t+k}
            # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps]
            # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14]
            game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps)

            rewards_list.append(game_segment.reward_segment)

            # for board games
            action_mask_segment.append(game_segment.action_mask_segment)
            to_play_segment.append(game_segment.to_play_segment)

            for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                # get the <num_unroll_steps+1> bootstrapped target obs
                td_steps_list.append(td_steps)
                # index of bootstrapped obs o_{t+td_steps}
                bootstrap_index = current_index + td_steps

                if bootstrap_index < game_segment_len:
                    value_mask.append(1)
                    # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps
                    beg_index = current_index - state_index
                    end_index = beg_index + self._cfg.model.frame_stack_num
                    # the stacked obs in time t
                    obs = game_obs[beg_index:end_index]
                else:
                    # bootstrapped position is outside the segment: masked out and padded with zero obs
                    value_mask.append(0)
                    obs = zero_obs

                value_obs_list.append(obs)

        reward_value_context = [
            value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
            action_mask_segment, to_play_segment
        ]
        return reward_value_context

    def _prepare_policy_non_reanalyzed_context(
            self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]
    ) -> List[Any]:
        """
        Overview:
            prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play
        Arguments:
            - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
            - game_segment_list (:obj:`list`): list of game segments
            - pos_in_game_segment_list (:obj:`list`): list transition index in game
        Returns:
            - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
        """
        child_visits = []
        game_segment_lens = []
        # for board games
        action_mask_segment, to_play_segment = [], []

        for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
            game_segment_len = len(game_segment)
            game_segment_lens.append(game_segment_len)
            # for board games
            action_mask_segment.append(game_segment.action_mask_segment)
            to_play_segment.append(game_segment.to_play_segment)

            child_visits.append(game_segment.child_visit_segment)

        policy_non_re_context = [
            pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
        ]
        return policy_non_re_context

    def _prepare_policy_reanalyzed_context(
            self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]
    ) -> List[Any]:
        """
        Overview:
            prepare the context of policies for calculating policy target in reanalyzing part.
        Arguments:
            - batch_index_list (:obj:`list`): start transition index in the replay buffer
            - game_segment_list (:obj:`list`): list of game segments
            - pos_in_game_segment_list (:obj:`list`): position of transition index in one game history
        Returns:
            - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices,
              child_visits, game_segment_lens, action_mask_segment, to_play_segment
        """
        zero_obs = game_segment_list[0].zero_obs()
        with torch.no_grad():
            # for policy
            policy_obs_list = []
            policy_mask = []
            # 0 -> Invalid target policy for padding outside of game segments,
            # 1 -> Previous target policy for game segments.
            # NOTE: ``rewards`` is collected here but is not part of the returned context.
            rewards, child_visits, game_segment_lens = [], [], []
            # for board games
            action_mask_segment, to_play_segment = [], []
            for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
                game_segment_len = len(game_segment)
                game_segment_lens.append(game_segment_len)
                rewards.append(game_segment.reward_segment)
                # for board games
                action_mask_segment.append(game_segment.action_mask_segment)
                to_play_segment.append(game_segment.to_play_segment)

                child_visits.append(game_segment.child_visit_segment)
                # prepare the corresponding observations
                game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps)
                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):

                    if current_index < game_segment_len:
                        policy_mask.append(1)
                        beg_index = current_index - state_index
                        end_index = beg_index + self._cfg.model.frame_stack_num
                        obs = game_obs[beg_index:end_index]
                    else:
                        policy_mask.append(0)
                        obs = zero_obs
                    policy_obs_list.append(obs)

        policy_re_context = [
            policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
            action_mask_segment, to_play_segment
        ]
        return policy_re_context

    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]:
        """
        Overview:
            prepare reward and value targets from the context of rewards and values.
        Arguments:
            - reward_value_context (:obj:`list`): the reward value context
            - model (:obj:`Any`): the target model, used through ``initial_inference`` (and MCTS search when
              ``self._cfg.use_root_value`` is set)
        Returns:
            - batch_rewards (:obj:`np.ndarray`): batch of target rewards (object dtype)
            - batch_target_values (:obj:`np.ndarray`): batch of target values (object dtype)
        """
        value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
            to_play_segment = reward_value_context  # noqa
        # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
        transition_batch_size = len(value_obs_list)
        game_segment_batch_size = len(pos_in_game_segment_list)

        to_play, action_mask = self._preprocess_to_play_and_action_mask(
            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
        )
        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
            action_mask = [
                list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
            ]
            # NOTE: in continuous action space env: we set all legal_actions as -1
            legal_actions = [
                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
            ]
        else:
            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

        batch_target_values, batch_rewards = [], []
        with torch.no_grad():
            value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
            # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
            slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
            network_output = []
            for i in range(slices):
                beg_index = self._cfg.mini_infer_size * i
                end_index = self._cfg.mini_infer_size * (i + 1)

                m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()

                # calculate the target value
                m_output = model.initial_inference(m_obs)

                if not model.training:
                    # if not in training, obtain the scalars of the value/reward
                    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
                        [
                            m_output.latent_state,
                            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
                            m_output.policy_logits
                        ]
                    )

                network_output.append(m_output)

            # concat the output slices after model inference
            if self._cfg.use_root_value:
                # use the root values from MCTS, as in EfficientZero
                # the root values have limited improvement but require much more GPU actors;
                _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
                    network_output, data_type='muzero'
                )
                reward_pool = reward_pool.squeeze().tolist()
                policy_logits_pool = policy_logits_pool.tolist()
                # one Dirichlet noise vector per transition, sized by the number of legal actions in its mask
                noises = [
                    np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
                                        ).astype(np.float32).tolist() for j in range(transition_batch_size)
                ]
                if self._cfg.mcts_ctree:
                    # cpp mcts_tree
                    roots = MCTSCtree.roots(transition_batch_size, legal_actions)
                    roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                    # do MCTS for a new policy with the recent target model
                    MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
                else:
                    # python mcts_tree
                    roots = MCTSPtree.roots(transition_batch_size, legal_actions)
                    roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                    # do MCTS for a new policy with the recent target model
                    MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

                roots_values = roots.get_values()
                value_list = np.array(roots_values)
            else:
                # use the predicted values
                value_list = concat_output_value(network_output)

            # get last state value
            if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
                # TODO(pu): for board_games, very important, to check
                # the discounted bootstrap value has its sign flipped when td_steps is odd
                value_list = value_list.reshape(-1) * np.array(
                    [
                        self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % 2 == 0 else
                        -self._cfg.discount_factor ** td_steps_list[i] for i in range(transition_batch_size)
                    ]
                )
            else:
                value_list = value_list.reshape(-1) * (
                    np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
                )

            # zero out the bootstrapped values whose bootstrap position is outside the game segment
            value_list = value_list * np.array(value_mask)
            value_list = value_list.tolist()
            # running index into the flattened per-transition lists (value_list, td_steps_list)
            value_index = 0

            for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(
                    game_segment_lens, rewards_list, pos_in_game_segment_list, to_play_segment):
                target_values = []
                target_rewards = []
                base_index = state_index
                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    bootstrap_index = current_index + td_steps_list[value_index]
                    # accumulate the discounted rewards between current_index and bootstrap_index
                    for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
                        if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
                            # TODO(pu): for board_games, very important, to check
                            # NOTE(review): ``i`` is the offset inside the reward slice, not an absolute
                            # position in the segment, yet it indexes ``to_play_list`` directly here -
                            # confirm whether ``current_index + i`` was intended.
                            if to_play_list[base_index] == to_play_list[i]:
                                value_list[value_index] += reward * self._cfg.discount_factor ** i
                            else:
                                value_list[value_index] += -reward * self._cfg.discount_factor ** i
                        else:
                            value_list[value_index] += reward * self._cfg.discount_factor ** i

                    if current_index < game_segment_len_non_re:
                        target_values.append(value_list[value_index])
                        target_rewards.append(reward_list[current_index])
                    else:
                        # positions beyond the end of the segment get zero targets
                        target_values.append(0)
                        target_rewards.append(0.0)
                        # TODO: check
                        # target_rewards.append(reward)
                    value_index += 1

                batch_rewards.append(target_rewards)
                batch_target_values.append(target_values)

        batch_rewards = np.asarray(batch_rewards, dtype=object)
        batch_target_values = np.asarray(batch_target_values, dtype=object)
        return batch_rewards, batch_target_values

    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
        """
        Overview:
            prepare policy targets from the reanalyzed context of policies
        Arguments:
            - policy_re_context (:obj:`List`): List of policy context to reanalyzed, or None.
            - model (:obj:`Any`): the target model, used through ``initial_inference`` and MCTS search.
        Returns:
            - batch_target_policies_re: the reanalyzed policy targets as an ``np.ndarray``; an empty ``list``
              when ``policy_re_context`` is None.
        """
        if policy_re_context is None:
            return []
        batch_target_policies_re = []

        # for board games
        policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
            to_play_segment = policy_re_context
        # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
        transition_batch_size = len(policy_obs_list)
        game_segment_batch_size = len(pos_in_game_segment_list)

        to_play, action_mask = self._preprocess_to_play_and_action_mask(
            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
        )
        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
            action_mask = [
                list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
            ]
            # NOTE: in continuous action space env: we set all legal_actions as -1
            legal_actions = [
                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
            ]
        else:
            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

        with torch.no_grad():
            policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
            # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
            slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
            network_output = []
            for i in range(slices):
                beg_index = self._cfg.mini_infer_size * i
                end_index = self._cfg.mini_infer_size * (i + 1)
                m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()
                m_output = model.initial_inference(m_obs)
                if not model.training:
                    # if not in training, obtain the scalars of the value/reward
                    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
                        [
                            m_output.latent_state,
                            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
                            m_output.policy_logits
                        ]
                    )

                network_output.append(m_output)

            _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
            reward_pool = reward_pool.squeeze().tolist()
            policy_logits_pool = policy_logits_pool.tolist()
            # NOTE(review): the noise vectors here are sized by the full action space, whereas
            # ``_compute_target_reward_value`` sizes them by the number of legal actions - confirm this is intended.
            noises = [
                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
                                    ).astype(np.float32).tolist() for _ in range(transition_batch_size)
            ]
            if self._cfg.mcts_ctree:
                # cpp mcts_tree
                roots = MCTSCtree.roots(transition_batch_size, legal_actions)
                roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                # do MCTS for a new policy with the recent target model
                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
            else:
                # python mcts_tree
                roots = MCTSPtree.roots(transition_batch_size, legal_actions)
                roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
                # do MCTS for a new policy with the recent target model
                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)

            roots_legal_actions_list = legal_actions
            roots_distributions = roots.get_distributions()
            # running index into the flattened per-transition lists (roots_distributions, policy_mask)
            policy_index = 0
            for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list):
                target_policies = []

                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    distributions = roots_distributions[policy_index]

                    if policy_mask[policy_index] == 0:
                        # NOTE: the invalid padding target policy, 0 is to make sure the corresponding cross_entropy_loss=0
                        target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
                    else:
                        if distributions is None:
                            # if at some obs, the legal_action is None, add the fake target_policy
                            target_policies.append(
                                list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
                            )
                        else:
                            if self._cfg.action_type == 'fixed_action_space':
                                # for atari/classic_control/box2d environments that only have one player.
                                sum_visits = sum(distributions)
                                policy = [visit_count / sum_visits for visit_count in distributions]
                                target_policies.append(policy)
                            else:
                                # for board games that have two players and legal_actions is dynamic
                                policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
                                # to make sure target_policies have the same dimension
                                sum_visits = sum(distributions)
                                policy = [visit_count / sum_visits for visit_count in distributions]
                                for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
                                    policy_tmp[legal_action] = policy[index]
                                target_policies.append(policy_tmp)

                    policy_index += 1

                batch_target_policies_re.append(target_policies)

        batch_target_policies_re = np.array(batch_target_policies_re)

        return batch_target_policies_re

    def _compute_target_policy_non_reanalyzed(
            self, policy_non_re_context: List[Any], policy_shape: Optional[int]
    ) -> np.ndarray:
        """
        Overview:
            prepare policy targets from the non-reanalyzed context of policies
        Arguments:
            - policy_non_re_context (:obj:`List`): List containing (or None):
                - pos_in_game_segment_list
                - child_visits
                - game_segment_lens
                - action_mask_segment
                - to_play_segment
            - policy_shape: self._cfg.model.action_space_size
        Returns:
            - batch_target_policies_non_re: the policy targets as an ``np.ndarray``; an empty ``list`` when
              ``policy_non_re_context`` is None.
        """
        batch_target_policies_non_re = []
        if policy_non_re_context is None:
            return batch_target_policies_non_re

        pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context
        game_segment_batch_size = len(pos_in_game_segment_list)
        transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)

        to_play, action_mask = self._preprocess_to_play_and_action_mask(
            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
        )

        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
            action_mask = [
                list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
            ]
            # NOTE: in continuous action space env: we set all legal_actions as -1
            legal_actions = [
                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
            ]
        else:
            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

        with torch.no_grad():
            # running index into the flattened per-transition ``legal_actions`` list
            policy_index = 0
            # 0 -> Invalid target policy for padding outside of game segments,
            # 1 -> Previous target policy for game segments.
            policy_mask = []
            for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits,
                                                                  pos_in_game_segment_list):
                target_policies = []

                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    if current_index < game_segment_len:
                        policy_mask.append(1)
                        # NOTE: child_visit is already a distribution
                        distributions = child_visit[current_index]
                        if self._cfg.action_type == 'fixed_action_space':
                            # for atari/classic_control/box2d environments that only have one player.
                            target_policies.append(distributions)
                        else:
                            # for board games that have two players.
                            policy_tmp = [0 for _ in range(policy_shape)]
                            for index, legal_action in enumerate(legal_actions[policy_index]):
                                # only the action in ``legal_action`` the policy logits is nonzero
                                policy_tmp[legal_action] = distributions[index]
                            target_policies.append(policy_tmp)
                    else:
                        # NOTE: the invalid padding target policy, 0 is to make sure the corresponding cross_entropy_loss=0
                        policy_mask.append(0)
                        target_policies.append([0 for _ in range(policy_shape)])

                    policy_index += 1

                batch_target_policies_non_re.append(target_policies)
        batch_target_policies_non_re = np.asarray(batch_target_policies_non_re)
        return batch_target_policies_non_re

    def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
        """
        Overview:
            Update the priority of training data.
        Arguments:
            - train_data (:obj:`List[np.ndarray]`): training data to be updated priority.
            - batch_priorities (:obj:`batch_priorities`): priorities to update to.
        NOTE:
            train_data = [current_batch, target_batch]
            current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
        """
        # negative indices are used so that an optional extra entry in current_batch does not shift the lookup
        indices = train_data[0][-3]
        make_times = train_data[0][-1]
        # only update the priorities for data still in replay buffer
        for i, idx in enumerate(indices):
            if make_times[i] > self.clear_time:
                self.game_pos_priorities[idx] = batch_priorities[i]