# Scraped page header (not code): author zjowowen, commit 079c32c ("init space"), file size 1.12 kB.
import pytest
from itertools import product
import numpy as np
import torch
from ding.rl_utils import coma_data, coma_error
# Random weights with every element >= 1, shaped (128, 4, 8) to match the
# (T, B, A) dimensions used by ``test_coma``.
random_weight = 1 + torch.rand(128, 4, 8)
# Parametrize over the unweighted case (``None``) and the weighted case.
weight_args = [None, random_weight]
@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_coma(weight):
    """Smoke-test ``coma_error`` on random inputs, with and without ``weight``.

    Checks that every term returned by ``coma_error`` is a scalar tensor, that
    no gradients exist before ``backward()``, and that back-propagating the sum
    of the terms populates gradients on ``logit`` and ``q_value``.
    """
    # T, B, A must match the shape of the module-level ``random_weight``
    # (128, 4, 8); N is the number of discrete choices sampled by ``action``.
    # Presumably T=time steps, B=batch, A=agents -- confirm against ding docs.
    T, B, A, N = 128, 4, 8, 32
    logit = torch.randn(T, B, A, N).requires_grad_(True)
    action = torch.randint(0, N, size=(T, B, A))
    reward = torch.rand(T, B)
    q_value = torch.randn(T, B, A, N).requires_grad_(True)
    target_q_value = torch.randn(T, B, A, N).requires_grad_(True)

    data = coma_data(logit, action, q_value, target_q_value, reward, weight)
    # NOTE(review): positional args look like gamma=0.99 and lambda=0.95;
    # confirm against coma_error's signature.
    loss = coma_error(data, 0.99, 0.95)

    # Every loss term must be a 0-dim (scalar) tensor.
    assert all(term.shape == () for term in loss)
    # No backward pass has run yet, so no gradients should be populated.
    assert logit.grad is None
    assert q_value.grad is None

    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(q_value.grad, torch.Tensor)