|
from itertools import product |
|
|
|
import pytest |
|
import torch |
|
from ding.torch_utils import is_differentiable |
|
|
|
from lzero.model.alphazero_model import PredictionNetwork |
|
|
|
action_space_size = [2, 3] |
|
batch_size = [100, 200] |
|
num_res_blocks = [3] |
|
num_channels = [3] |
|
value_head_channels = [8] |
|
policy_head_channels = [8] |
|
fc_value_layers = [[ |
|
16, |
|
]] |
|
fc_policy_layers = [[ |
|
16, |
|
]] |
|
output_support_size = [2] |
|
observation_shape = [1, 3, 3] |
|
|
|
prediction_network_args = list( |
|
product( |
|
action_space_size, |
|
batch_size, |
|
num_res_blocks, |
|
num_channels, |
|
value_head_channels, |
|
policy_head_channels, |
|
fc_value_layers, |
|
fc_policy_layers, |
|
output_support_size, |
|
) |
|
) |
|
|
|
|
|
@pytest.mark.unittest |
|
class TestAlphaZeroModel: |
|
|
|
def output_check(self, model, outputs): |
|
if isinstance(outputs, torch.Tensor): |
|
loss = outputs.sum() |
|
elif isinstance(outputs, list): |
|
loss = sum([t.sum() for t in outputs]) |
|
elif isinstance(outputs, dict): |
|
loss = sum([v.sum() for v in outputs.values()]) |
|
is_differentiable(loss, model) |
|
|
|
@pytest.mark.parametrize( |
|
'action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels, fc_value_layers, fc_policy_layers, output_support_size', |
|
prediction_network_args |
|
) |
|
def test_prediction_network( |
|
self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, |
|
policy_head_channels, |
|
fc_value_layers, fc_policy_layers, output_support_size |
|
): |
|
obs = torch.rand(batch_size, num_channels, 3, 3) |
|
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2] |
|
flatten_output_size_for_policy_head = policy_head_channels * observation_shape[1] * observation_shape[2] |
|
|
|
|
|
|
|
prediction_network = PredictionNetwork( |
|
action_space_size=action_space_size, |
|
continuous_action_space=False, |
|
num_res_blocks=num_res_blocks, |
|
num_channels=num_channels, |
|
value_head_channels=value_head_channels, |
|
policy_head_channels=policy_head_channels, |
|
fc_value_layers=fc_value_layers, |
|
fc_policy_layers=fc_policy_layers, |
|
output_support_size=output_support_size, |
|
flatten_output_size_for_value_head=flatten_output_size_for_value_head, |
|
flatten_output_size_for_policy_head=flatten_output_size_for_policy_head, |
|
last_linear_layer_init_zero=True, |
|
) |
|
policy, value = prediction_network(obs) |
|
assert policy.shape == torch.Size([batch_size, action_space_size]) |
|
assert value.shape == torch.Size([batch_size, output_support_size]) |
|
|
|
|
|
if __name__ == "__main__": |
|
action_space_size = 2 |
|
batch_size = 100 |
|
num_res_blocks = 3 |
|
num_channels = 3 |
|
reward_head_channels = 2 |
|
value_head_channels = 8 |
|
policy_head_channels = 8 |
|
fc_value_layers = [16] |
|
fc_policy_layers = [16] |
|
output_support_size = 2 |
|
observation_shape = [1, 3, 3] |
|
obs = torch.rand(batch_size, num_channels, 3, 3) |
|
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2] |
|
flatten_output_size_for_policy_head = policy_head_channels * observation_shape[1] * observation_shape[2] |
|
print('=' * 20) |
|
print( |
|
batch_size, num_res_blocks, num_channels, action_space_size, reward_head_channels, fc_value_layers, |
|
fc_policy_layers, output_support_size |
|
) |
|
print('=' * 20) |
|
prediction_network = PredictionNetwork( |
|
action_space_size=action_space_size, |
|
num_res_blocks=num_res_blocks, |
|
num_channels=num_channels, |
|
value_head_channels=value_head_channels, |
|
policy_head_channels=policy_head_channels, |
|
fc_value_layers=fc_value_layers, |
|
fc_policy_layers=fc_policy_layers, |
|
output_support_size=output_support_size, |
|
flatten_output_size_for_value_head=flatten_output_size_for_value_head, |
|
flatten_output_size_for_policy_head=flatten_output_size_for_policy_head, |
|
last_linear_layer_init_zero=True, |
|
) |
|
policy, value = prediction_network(obs) |
|
assert policy.shape == torch.Size([batch_size, action_space_size]) |
|
assert value.shape == torch.Size([batch_size, output_support_size]) |
|
|