from typing import Any, Tuple

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.buffer import MuZeroGameBuffer
from lzero.mcts.utils import prepare_observation


@BUFFER_REGISTRY.register('game_buffer_gumbel_muzero')
class GumbelMuZeroGameBuffer(MuZeroGameBuffer):
    """
    Overview:
        The specific game buffer for Gumbel MuZero policy. It reuses ``MuZeroGameBuffer`` and only overrides
        ``_make_batch`` so that the batch inputs carry the improved policy stored in each game segment.
    """

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            First sample ``orig_data`` through ``_sample_orig_data()``, then prepare the context of a batch:
                - reward_value_context: the context of reanalyzed value targets
                - policy_re_context: the context of reanalyzed policy targets
                - policy_non_re_context: the context of non-reanalyzed policy targets
                - current_batch: the inputs of the batch
            The core difference from MuZero is the preprocessing of ``improved_policy``: for every sampled
            position the improved policy of the next unroll steps is read from the game segment and padded
            with random Dirichlet samples up to ``num_unroll_steps + 1`` entries.
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
            - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch.
              ``current_batch`` is the list ``[obs, action, improved_policy, mask, batch_index, weights,
              make_time]`` with every element converted by ``np.asarray``. ``policy_re_context`` is ``None``
              when no sample is reanalyzed; ``policy_non_re_context`` is ``None`` when every sample is.
        """
        num_unroll_steps = self._cfg.num_unroll_steps

        # obtain the batch context from replay buffer
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        # the number of samples actually returned may differ from the requested batch_size
        batch_size = len(batch_index_list)

        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        # The main difference between Gumbel MuZero and MuZero lies in the preprocessing of improved_policy.
        obs_list, action_list, improved_policy_list, mask_list = [], [], [], []
        for game, pos_in_game_segment in zip(game_segment_list, pos_in_game_segment_list):
            end = pos_in_game_segment + num_unroll_steps
            actions = game.action_segment[pos_in_game_segment:end].tolist()

            # NOTE(review): this slice yields at most ``num_unroll_steps`` entries, so the final one of the
            # ``num_unroll_steps + 1`` targets built below is always Dirichlet padding. If
            # ``improved_policy_probs`` holds an entry for position ``end`` (e.g. padded from the next segment),
            # slicing to ``end + 1`` may be intended -- confirm against the policy's loss before changing.
            improved_policy = game.improved_policy_probs[pos_in_game_segment:end]
            if not isinstance(improved_policy, list):
                improved_policy = improved_policy.tolist()

            # mask for invalid actions (out of trajectory): 1 for each real action, 0 for the padded rest,
            # always of length num_unroll_steps + 1
            num_valid = len(actions)
            mask = [1.] * num_valid + [0.] * (num_unroll_steps + 1 - num_valid)

            # pad random actions up to num_unroll_steps
            actions += [
                np.random.randint(0, game.action_space_size) for _ in range(num_unroll_steps - num_valid)
            ]

            # pad improved policy with Dirichlet samples so every padded row sums to 1
            improved_policy.extend(
                np.random.dirichlet(
                    np.ones(game.action_space_size), size=num_unroll_steps + 1 - len(improved_policy)
                )
            )

            # obtain the input observations
            # pad if length of obs in game_segment is less than stack+num_unroll_steps
            # e.g. stack+num_unroll_steps = 4+5
            obs_list.append(game.get_unroll_obs(pos_in_game_segment, num_unroll_steps=num_unroll_steps, padding=True))
            action_list.append(actions)
            improved_policy_list.append(improved_policy)
            mask_list.append(mask)

        # formalize the input observations
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # formalize the inputs of a batch
        current_batch = [
            np.asarray(item) for item in (
                obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list,
                make_time_list
            )
        ]

        total_transitions = self.get_num_of_transitions()

        # obtain the context of value targets
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )

        # Only reanalyze the first ``reanalyze_ratio`` fraction of the batch:
        #   [0, reanalyze_num) -> reanalyzed policy, [reanalyze_num, batch_size) -> non-reanalyzed policy.
        # (Original note: if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to
        # its generated env_steps -- that ordering is established by the sampler, not here.)
        reanalyze_num = int(batch_size * reanalyze_ratio)

        # reanalyzed policy
        if reanalyze_num > 0:
            # obtain the context of reanalyzed policy targets
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # non reanalyzed policy
        if reanalyze_num < batch_size:
            # obtain the context of non-reanalyzed policy targets
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context