import numpy as np
import pytest
import torch

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
from lzero.policy import select_action

# args = ['EfficientZero', 'MuZero']
args = ['MuZero']


@pytest.mark.unittest
@pytest.mark.parametrize('test_algo', args)
def test_game_segment(test_algo):
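    """
    End-to-end rollout test: build the model and evaluator environments, run MCTS on top of
    ``initial_inference`` at every step, act greedily on the resulting visit-count
    distributions, and store observations, actions, rewards and search statistics
    into ``GameSegment`` objects.
    """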
    # import different modules according to ``test_algo``
    if test_algo == 'EfficientZero':
        from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
        from lzero.model.efficientzero_model import EfficientZeroModel as Model
        from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
        from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
        envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]

    elif test_algo == 'MuZero':
        from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
        from lzero.model.muzero_model import MuZeroModel as Model
        from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config
        from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
        envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)]

    # create model
    model = Model(**config.policy.model)
    if config.policy.cuda and torch.cuda.is_available():
        config.policy.device = 'cuda'
    else:
        config.policy.device = 'cpu'
    model.to(config.policy.device)
    model.eval()

    with torch.no_grad():
        # initializations
        init_observations = [env.reset() for env in envs]
        dones = np.array([False for _ in range(config.env.evaluator_env_num)])
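        # one GameSegment per environment; it buffers the stacked observations, actions,
        # rewards and MCTS search statistics collected during the rollout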
        game_segments = [
            GameSegment(
                envs[i].action_space, game_segment_length=config.policy.game_segment_length, config=config.policy
            ) for i in range(config.env.evaluator_env_num)
        ]
        for i in range(config.env.evaluator_env_num):
            game_segments[i].reset(
                [init_observations[i]['observation'] for _ in range(config.policy.model.frame_stack_num)]
            )
        episode_rewards = np.zeros(config.env.evaluator_env_num)

        # roll out one episode in each of the evaluator environments
        while not dones.all():
            # stack the most recent frame_stack_num observations for every env and move them to the device
            stack_obs = [game_segment.get_obs() for game_segment in game_segments]
            stack_obs = prepare_observation(stack_obs, config.policy.model.model_type)
            stack_obs = torch.from_numpy(np.array(stack_obs)).to(config.policy.device)

            # ==============================================================
            # the core initial_inference.
            # ==============================================================
            network_output = model.initial_inference(stack_obs)

            # process the network output
            policy_logits_pool = network_output.policy_logits.detach().cpu().numpy().tolist()
            latent_state_roots = network_output.latent_state.detach().cpu().numpy()

            if test_algo == 'EfficientZero':
                reward_hidden_state_roots = network_output.reward_hidden_state
                value_prefix_pool = network_output.value_prefix
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
                # for the Atari env, all actions are legal
                legal_actions_list = [
                    [i for i in range(config.policy.model.action_space_size)]
                    for _ in range(config.env.evaluator_env_num)
                ]
            elif test_algo == 'MuZero':
                reward_pool = network_output.reward
                # for board games, legal actions are derived from the action mask of the latest observation
                legal_actions_list = [
                    [a for a, x in enumerate(init_observations[i]['action_mask']) if x == 1]
                    for i in range(config.env.evaluator_env_num)
                ]

            # ``to_play`` is set to -1 (null padding) for Atari games and for board games in vs_bot_mode
            to_play = [-1 for _ in range(config.env.evaluator_env_num)]

            if test_algo == 'EfficientZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)

            elif test_algo == 'MuZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, to_play)

            roots_distributions = roots.get_distributions()
            roots_values = roots.get_values()

            for i in range(config.env.evaluator_env_num):
                # skip environments whose episode has already finished
                if dones[i]:
                    continue
                distributions, value, env = roots_distributions[i], roots_values[i], envs[i]
                # ``deterministic=True`` means the most-visited (argmax) action is selected instead of sampling.
                action, _ = select_action(distributions, temperature=1, deterministic=True)
                # ==============================================================
                # the core env.step.
                # ==============================================================
                obs, reward, done, info = env.step(action)
                # keep the latest full observation so the action mask used for the MCTS root stays current
                init_observations[i] = obs
                obs = obs['observation']

                game_segments[i].store_search_stats(distributions, value)
                game_segments[i].append(action, obs, reward)

                dones[i] = done
                episode_rewards[i] += reward

        for env in envs:
            env.close()
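

# A minimal sketch for running this check directly without pytest, assuming the LightZero
# package and the test configs imported above are available on the current PYTHONPATH.
if __name__ == "__main__":
    for algo in args:
        test_game_segment(algo)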