# Source: LightZero, lzero/mcts/tests/test_game_segment.py
# (Hugging Face Space "gomoku", uploaded by zjowowen, commit 079c32c "init space", 6.01 kB)
import numpy as np
import pytest
import torch
from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
from lzero.policy import select_action
# Algorithms exercised by ``test_game_segment`` through ``pytest.mark.parametrize``.
# The test body also supports 'EfficientZero'; it is not listed here, so that branch
# is not run. Use ``args = ['EfficientZero', 'MuZero']`` to exercise both.
args = ["MuZero"]
@pytest.mark.unittest
@pytest.mark.parametrize('test_algo', args)
def test_game_segment(test_algo):
    """
    Roll out complete evaluation episodes with MCTS and record them in ``GameSegment`` objects.

    For each of ``config.env.evaluator_env_num`` environments, the test first resets a
    ``GameSegment`` with ``frame_stack_num`` copies of the initial observation. It then loops
    until every environment is done. Each iteration takes the environments that are still
    running and:

    - stacks their observations and runs ``model.initial_inference``;
    - runs a ctree MCTS search from those roots;
    - selects the argmax action of each root's visit distribution;
    - steps the environment;
    - stores the search statistics and the transition in the environment's segment.

    Arguments:
        - test_algo (:obj:`str`): Algorithm to exercise, either ``'EfficientZero'`` or ``'MuZero'``.

    Raises:
        - ValueError: If ``test_algo`` is not one of the supported algorithms.
    """
    # import different modules according to ``test_algo``
    if test_algo == 'EfficientZero':
        from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
        from lzero.model.efficientzero_model import EfficientZeroModel as Model
        from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
        from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
        envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]
    elif test_algo == 'MuZero':
        from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
        from lzero.model.muzero_model import MuZeroModel as Model
        from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config
        from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
        envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)]
    else:
        raise ValueError(f'Unsupported test_algo: {test_algo!r}')

    env_num = config.env.evaluator_env_num

    try:
        # create model
        model = Model(**config.policy.model)
        if config.policy.cuda and torch.cuda.is_available():
            config.policy.device = 'cuda'
        else:
            config.policy.device = 'cpu'
        model.to(config.policy.device)
        model.eval()

        with torch.no_grad():
            # initializations
            init_observations = [env.reset() for env in envs]
            # Most recent raw observation dict of each env. For board games, its
            # 'action_mask' gives the legal actions for the next search; the mask
            # changes after every move, so the initial mask must not be reused.
            latest_observations = list(init_observations)
            dones = np.zeros(env_num, dtype=bool)
            game_segments = [
                GameSegment(
                    envs[i].action_space, game_segment_length=config.policy.game_segment_length, config=config.policy
                ) for i in range(env_num)
            ]
            for segment, init_obs in zip(game_segments, init_observations):
                segment.reset([init_obs['observation'] for _ in range(config.policy.model.frame_stack_num)])
            episode_rewards = np.zeros(env_num)

            while not dones.all():
                # Search and step only the envs whose episode is still running;
                # a finished env must not be stepped again.
                active_ids = [i for i in range(env_num) if not dones[i]]
                active_num = len(active_ids)

                stack_obs = [game_segments[i].get_obs() for i in active_ids]
                stack_obs = prepare_observation(stack_obs, config.policy.model.model_type)
                stack_obs = torch.from_numpy(np.array(stack_obs)).to(config.policy.device)

                # ==============================================================
                # the core initial_inference.
                # ==============================================================
                network_output = model.initial_inference(stack_obs)
                # process the network output
                policy_logits_pool = network_output.policy_logits.detach().cpu().numpy().tolist()
                latent_state_roots = network_output.latent_state.detach().cpu().numpy()
                # null padding for the atari games and board_games in vs_bot_mode
                to_play = [-1 for _ in range(active_num)]

                if test_algo == 'EfficientZero':
                    value_prefix_pool = network_output.value_prefix
                    reward_hidden_state_roots = (
                        network_output.reward_hidden_state[0].detach().cpu().numpy(),
                        network_output.reward_hidden_state[1].detach().cpu().numpy(),
                    )
                    # for atari env, all actions are legal
                    legal_actions_list = [
                        list(range(config.policy.model.action_space_size)) for _ in range(active_num)
                    ]
                    roots = MCTSCtree.roots(active_num, legal_actions_list)
                    roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                    MCTSCtree(config.policy).search(
                        roots, model, latent_state_roots, reward_hidden_state_roots, to_play
                    )
                elif test_algo == 'MuZero':
                    reward_pool = network_output.reward
                    # for board games, legal actions come from the current action mask
                    legal_actions_list = [
                        [a for a, x in enumerate(latest_observations[i]['action_mask']) if x == 1]
                        for i in active_ids
                    ]
                    roots = MCTSCtree.roots(active_num, legal_actions_list)
                    roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
                    MCTSCtree(config.policy).search(roots, model, latent_state_roots, to_play)

                roots_distributions = roots.get_distributions()
                roots_values = roots.get_values()

                # Root results are indexed by position in ``active_ids``, not by env id.
                for batch_idx, env_id in enumerate(active_ids):
                    distributions, value = roots_distributions[batch_idx], roots_values[batch_idx]
                    # ``deterministic=True`` indicates that we select the argmax action instead of sampling.
                    action, _ = select_action(distributions, temperature=1, deterministic=True)
                    # ==============================================================
                    # step the environment with the selected action.
                    # ==============================================================
                    obs, reward, done, _ = envs[env_id].step(action)
                    latest_observations[env_id] = obs

                    game_segments[env_id].store_search_stats(distributions, value)
                    game_segments[env_id].append(action, obs['observation'], reward)

                    dones[env_id] = done
                    episode_rewards[env_id] += reward
    finally:
        # Always release the environments, even if model creation or the rollout fails.
        for env in envs:
            env.close()