import pytest
import torch
from ding.rl_utils import coma_data, coma_error

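# A strictly positive elementwise weight with shape (T, B, A) = (128, 4, 8),
# used to exercise the weighted branch of the loss alongside weight=None.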
random_weight = torch.rand(128, 4, 8) + 1
weight_args = [None, random_weight]


@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_coma(weight):
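    """
    Overview:
        Check that coma_error returns scalar losses and only populates gradients
        on logit and q_value after backward() is called.
    """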
    T, B, A, N = 128, 4, 8, 32
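    # Shape convention (assumed from COMA's multi-agent setting): T is the
    # trajectory length, B the batch size, A the agent number, N the action dim.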
    logit = torch.randn(
        T,
        B,
        A,
        N,
    ).requires_grad_(True)
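    # Integer actions sampled uniformly from the N-way action space.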
    action = torch.randint(
        0, N, size=(
            T,
            B,
            A,
        )
    )
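    # Per-step scalar reward, plus current and target critic Q-value estimates.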
    reward = torch.rand(T, B)
    q_value = torch.randn(T, B, A, N).requires_grad_(True)
    target_q_value = torch.randn(T, B, A, N).requires_grad_(True)
    data = coma_data(logit, action, q_value, target_q_value, reward, weight)
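    # The two floats are the discount factor and the lambda of TD(lambda)
    # returns (names assumed from the usual COMA formulation).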
    loss = coma_error(data, 0.99, 0.95)
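    # Every component of the returned loss should be a 0-dim (scalar) tensor.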
    assert all(item.shape == tuple() for item in loss)
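    # Before backward() is called, no gradients should exist on the leaf tensors.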
    assert logit.grad is None
    assert q_value.grad is None
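    # Summing the loss components and backpropagating should reach both leaves.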
    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(q_value.grad, torch.Tensor)