gomoku / LightZero /lzero /mcts /buffer /game_buffer_gumbel_muzero.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
5.87 kB
from typing import Any, Tuple
import numpy as np
from ding.utils import BUFFER_REGISTRY
from lzero.mcts.buffer import MuZeroGameBuffer
from lzero.mcts.utils import prepare_observation
@BUFFER_REGISTRY.register('game_buffer_gumbel_muzero')
class GumbelMuZeroGameBuffer(MuZeroGameBuffer):
    """
    Overview:
        The specific game buffer for Gumbel MuZero policy. It overrides ``_make_batch`` so that every
        sampled item also carries the ``improved_policy_probs`` stored in its game segment.
    """

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any, ...]:
        """
        Overview:
            First sample orig_data through ``_sample_orig_data()``, then prepare the context of a batch:
                reward_value_context:        the context of reanalyzed value targets
                policy_re_context:           the context of reanalyzed policy targets
                policy_non_re_context:       the context of non-reanalyzed policy targets
                current_batch:               the inputs of batch
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. After sampling it is \
                replaced by the number of batch indices that were actually returned.
            - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, \
                current_batch. ``policy_re_context`` is ``None`` when no item is reanalyzed and \
                ``policy_non_re_context`` is ``None`` when every item is reanalyzed. ``current_batch`` is a \
                list of ``np.ndarray`` in the order: observations, actions, improved policies, masks, \
                batch indices, weights, make times.
        """
        # obtain the batch context from replay buffer
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        batch_size = len(batch_index_list)
        num_unroll_steps = self._cfg.num_unroll_steps

        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        # The main difference between Gumbel MuZero and MuZero lies in the preprocessing of improved_policy.
        obs_list, action_list, improved_policy_list, mask_list = [], [], [], []

        # prepare the inputs of a batch
        for idx in range(batch_size):
            game = game_segment_list[idx]
            pos_in_game_segment = pos_in_game_segment_list[idx]
            unroll_end = pos_in_game_segment + num_unroll_steps

            actions_tmp = game.action_segment[pos_in_game_segment:unroll_end].tolist()

            _improved_policy = game.improved_policy_probs[pos_in_game_segment:unroll_end]
            if not isinstance(_improved_policy, list):
                _improved_policy = _improved_policy.tolist()

            # add mask for invalid actions (out of trajectory):
            # 1. for every real action, 0. for the padding up to num_unroll_steps + 1 entries
            mask_tmp = [1.] * len(actions_tmp)
            mask_tmp += [0.] * (num_unroll_steps + 1 - len(mask_tmp))

            # pad random action
            # NOTE: the draws are kept one call per padded action, and before the dirichlet draw below,
            # so that the consumption order of the global numpy random stream is unchanged.
            actions_tmp += [
                np.random.randint(0, game.action_space_size) for _ in range(num_unroll_steps - len(actions_tmp))
            ]

            # pad improved policy with a value such that the sum of the values is equal to 1
            _improved_policy.extend(
                np.random.dirichlet(
                    np.ones(game.action_space_size), size=num_unroll_steps + 1 - len(_improved_policy)
                )
            )

            # obtain the input observations
            # pad if length of obs in game_segment is less than stack+num_unroll_steps
            # e.g. stack+num_unroll_steps = 4+5
            obs_list.append(
                game.get_unroll_obs(pos_in_game_segment, num_unroll_steps=num_unroll_steps, padding=True)
            )
            action_list.append(actions_tmp)
            improved_policy_list.append(_improved_policy)
            mask_list.append(mask_tmp)

        # formalize the input observations
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # formalize the inputs of a batch: every entry is converted to a numpy array
        current_batch = [
            np.asarray(item) for item in (
                obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list,
                make_time_list
            )
        ]

        total_transitions = self.get_num_of_transitions()

        # obtain the context of value targets
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )

        # only reanalyze recent reanalyze_ratio (e.g. 50%) data
        # if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
        # 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
        reanalyze_num = int(batch_size * reanalyze_ratio)

        # reanalyzed policy
        if reanalyze_num > 0:
            # obtain the context of reanalyzed policy targets
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # non reanalyzed policy
        if reanalyze_num < batch_size:
            # obtain the context of non-reanalyzed policy targets
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context