zjowowen's picture
init space
079c32c
raw
history blame
1.9 kB
import pytest
import torch
from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
@pytest.mark.unittest
def test_vtrace_discrete_action():
    """Check that the discrete-action V-trace loss is scalar and differentiable.

    Random inputs are built with sizes ``T, B, N = 4, 8, 16`` (presumably
    trajectory length, batch size and number of discrete actions -- TODO
    confirm against ``vtrace_data``). The test then verifies that:

    * every term returned by ``vtrace_error_discrete_action`` is a 0-dim tensor;
    * no gradient exists on the leaf inputs before ``backward`` is called;
    * after back-propagating the summed loss, both ``target_output`` and
      ``value`` have received a gradient tensor.
    """
    T, B, N = 4, 8, 16
    # Leaf tensors that require grad, so their ``.grad`` can be inspected later.
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    target_output = torch.randn(T, B, N).requires_grad_(True)
    behaviour_output = torch.randn(T, B, N)
    action = torch.randint(0, N, size=(T, B))

    # The last positional argument is passed as None, exactly as before
    # (presumably an optional per-sample weight -- TODO confirm).
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)

    # Each loss component must be a scalar (empty shape).
    assert all(term.shape == tuple() for term in loss)

    # Nothing has been back-propagated yet, so leaf gradients are still unset.
    assert target_output.grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()

    # Inspect ``.grad`` rather than the tensors themselves: the inputs are
    # always tensors, so checking them would pass even with a broken graph.
    assert isinstance(target_output.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)
@pytest.mark.unittest
def test_vtrace_continuous_action():
    """Check that the continuous-action V-trace loss is scalar and differentiable.

    Target and behaviour policies are given as dicts with ``'mu'`` and
    ``'sigma'`` entries; ``sigma`` is produced with ``torch.exp`` so it is
    strictly positive. Sizes are ``T, B, N = 4, 8, 16`` (presumably trajectory
    length, batch size and action dimension -- TODO confirm against
    ``vtrace_data``). The test verifies that:

    * every term returned by ``vtrace_error_continuous_action`` is a 0-dim
      tensor;
    * no gradient exists on the leaf inputs before ``backward`` is called;
    * after back-propagating the summed loss, ``mu``, the leaf behind
      ``sigma`` and ``value`` have all received a gradient tensor.
    """
    T, B, N = 4, 8, 16
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)

    # ``torch.exp(...)`` yields a non-leaf tensor whose ``.grad`` is never
    # populated, so keep a handle on the underlying leaf to observe gradients.
    target_mu = torch.randn(T, B, N).requires_grad_(True)
    log_sigma = torch.randn(T, B, N).requires_grad_(True)
    target_output = {
        'mu': target_mu,
        'sigma': torch.exp(log_sigma),
    }
    behaviour_output = {
        'mu': torch.randn(T, B, N),
        'sigma': torch.exp(torch.randn(T, B, N)),
    }
    action = torch.randn((T, B, N))

    # The last positional argument is passed as None, exactly as before
    # (presumably an optional per-sample weight -- TODO confirm).
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
    loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)

    # Each loss component must be a scalar (empty shape).
    assert all(term.shape == tuple() for term in loss)

    # Nothing has been back-propagated yet, so leaf gradients are still unset.
    assert target_output['mu'].grad is None
    assert log_sigma.grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()

    # Inspect ``.grad`` on the leaves: checking the inputs themselves would
    # pass even if the loss were disconnected from them.
    assert isinstance(target_output['mu'].grad, torch.Tensor)
    assert isinstance(log_sigma.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)