import pytest
import torch

from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action


@pytest.mark.unittest
def test_vtrace_discrete_action():
    """Check the discrete-action v-trace loss returns scalars and backpropagates.

    Builds random inputs with ``T=4``, ``B=8``, ``N=16``: ``value`` is
    ``(T + 1, B)``, ``reward`` and ``action`` are ``(T, B)``, and both
    outputs are ``(T, B, N)``. Only ``value`` and ``target_output`` require
    grad, so those are the tensors whose ``.grad`` is inspected.
    """
    T, B, N = 4, 8, 16
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    target_output = torch.randn(T, B, N).requires_grad_(True)
    behaviour_output = torch.randn(T, B, N)
    action = torch.randint(0, N, size=(T, B))
    # The last field of vtrace_data is passed as None, as in the original test.
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)

    loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    # Every returned loss term must be a 0-dim (scalar) tensor.
    assert all(term.shape == tuple() for term in loss)

    # No gradient may exist before backward() has been called.
    assert target_output.grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()

    # Inspect .grad, not the tensor itself: the tensors are trivially
    # torch.Tensor instances, which would make the check vacuous.
    assert isinstance(target_output.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)


@pytest.mark.unittest
def test_vtrace_continuous_action():
    """Check the continuous-action v-trace loss returns scalars and backpropagates.

    ``target_output`` and ``behaviour_output`` are dicts with ``'mu'`` and
    ``'sigma'`` entries of shape ``(T, B, N)``; ``sigma`` is produced with
    ``torch.exp`` so it is strictly positive. ``action`` is ``(T, B, N)``.
    """
    T, B, N = 4, 8, 16
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)

    target_mu = torch.randn(T, B, N).requires_grad_(True)
    # Keep a handle on the leaf tensor: exp() yields a non-leaf tensor whose
    # .grad is never populated, so gradient checks must target log_sigma.
    target_log_sigma = torch.randn(T, B, N).requires_grad_(True)
    target_output = {
        'mu': target_mu,
        'sigma': torch.exp(target_log_sigma),
    }
    behaviour_output = {
        'mu': torch.randn(T, B, N),
        'sigma': torch.exp(torch.randn(T, B, N)),
    }
    action = torch.randn((T, B, N))
    # The last field of vtrace_data is passed as None, as in the original test.
    data = vtrace_data(target_output, behaviour_output, action, value, reward, None)

    loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    # Every returned loss term must be a 0-dim (scalar) tensor.
    assert all(term.shape == tuple() for term in loss)

    # No gradient may exist on any leaf before backward() has been called.
    assert target_mu.grad is None
    assert target_log_sigma.grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()

    # Gradients must now be populated on every leaf that requires grad.
    assert isinstance(target_mu.grad, torch.Tensor)
    assert isinstance(target_log_sigma.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)