gomoku / LightZero /lzero /mcts /tests /test_game_buffer.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
2.81 kB
import numpy as np
import pytest
from easydict import EasyDict
from ding.torch_utils import to_list
from lzero.mcts.buffer.game_buffer_efficientzero import EfficientZeroGameBuffer
# Settings handed to EfficientZeroGameBuffer by every test in this module.
config = EasyDict(
    {
        'batch_size': 10,
        'transition_num': 20,
        'priority_prob_alpha': 0.6,
        'priority_prob_beta': 0.4,
        'replay_buffer_size': 10000,
        'env_type': 'not_board_games',
        'use_priority': True,
        'action_type': 'fixed_action_space',
    }
)
@pytest.mark.unittest
def test_push():
    """Check that single and bulk pushes are reflected in the segment count.

    Covers ``_push_game_segment``, ``push_game_segments``, clearing the
    underlying segment list, and pushing again after the clear.
    """
    buffer = EfficientZeroGameBuffer(config)

    # Fake segment: ten (s, a, r) triples with one priority per step.
    data = [[1, 1, 1] for _ in range(10)]
    meta = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}

    # _push_game_segment: one segment per call.
    for i in range(20):
        buffer._push_game_segment(to_list(np.multiply(i, data)), meta)
    assert buffer.get_num_of_game_segments() == 20

    # push_game_segments: two segments in a single call.
    buffer.push_game_segments([[data, data], [meta, meta]])
    assert buffer.get_num_of_game_segments() == 22

    # Clear the stored segments in place.
    del buffer.game_segment_buffer[:]
    assert buffer.get_num_of_game_segments() == 0

    # Pushing after a clear must start counting from zero again; the original
    # test performed these pushes without verifying the outcome.
    for i in range(5):
        buffer._push_game_segment(to_list(np.multiply(i, data)), meta)
    assert buffer.get_num_of_game_segments() == 5
@pytest.mark.unittest
def test_update_priority():
    """Check that ``update_priority`` writes each new priority at its batch index."""
    buffer = EfficientZeroGameBuffer(config)

    # Fake segment: ten (s, a, r) triples with one priority per step.
    data = [[1, 1, 1] for _ in range(10)]
    meta = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}

    # _push_game_segment
    for i in range(20):
        buffer._push_game_segment(to_list(np.multiply(i, data)), meta)
    assert buffer.get_num_of_game_segments() == 20

    # Fake training batch.
    # train_data = [current_batch, target_batch]
    # current_batch = [obs_lst, action_lst, mask_lst, batch_index_list, weights, make_time_lst]
    indices = [0, 1]
    make_time = [999, 1000]
    train_data = [[[], [], [], indices, [], make_time], []]
    batch_priorities = [0.999, 0.8]

    buffer.update_priority(train_data, batch_priorities)

    # Verify every updated position, not only the first one as the original
    # test did.
    # NOTE(review): assumes both make_time values are recent enough for the
    # update to be applied -- confirm against update_priority.
    for index, priority in zip(indices, batch_priorities):
        assert buffer.game_pos_priorities[index] == priority
@pytest.mark.unittest
def test_sample_orig_data():
    """Check that ``_sample_orig_data`` returns a sample sized to the requested batch."""
    buffer = EfficientZeroGameBuffer(config)

    # Two fake segments of ten (s, a, r) triples each, one priority per step.
    data_1 = [[1, 1, 1] for _ in range(10)]
    meta_1 = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}
    data_2 = [[1, 1, 1] for _ in range(10, 20)]
    meta_2 = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}

    # push
    buffer._push_game_segment(data_1, meta_1)
    buffer._push_game_segment(data_2, meta_2)

    batch_size = 2
    context = buffer._sample_orig_data(batch_size=batch_size)
    # context = (game_lst, game_pos_lst, indices_lst, weights, make_time)
    print(context)

    # The original test only printed the sample. Unpacking enforces the
    # five-field layout noted above.
    # NOTE(review): the per-sample list lengths below are assumed to equal
    # batch_size -- confirm against _sample_orig_data.
    game_lst, game_pos_lst, indices_lst, weights, make_time = context
    assert len(game_lst) == batch_size
    assert len(game_pos_lst) == batch_size
    assert len(indices_lst) == batch_size