import pytest
from itertools import product
import numpy as np
import torch
from ding.rl_utils import a2c_data, a2c_error, a2c_error_continuous

# Per-sample weights; the length (4) matches the batch size ``B`` used by the
# tests below. The ``+ 1`` shifts the uniform [0, 1) samples to [1, 2).
random_weight = torch.rand(4) + 1
# Each test runs once without weights and once with the random weights.
weight_args = [None, random_weight]


@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_a2c(weight):
    """Check ``a2c_error`` on a discrete-action batch.

    Verifies that every returned loss term has scalar shape, that no
    gradients exist before ``backward()``, and that ``logit`` and ``value``
    both receive tensor gradients after back-propagating the summed loss.

    Args:
        weight: ``None`` or a tensor of length ``B`` passed through to
            ``a2c_data``.
    """
    B, N = 4, 32
    logit = torch.randn(B, N).requires_grad_(True)
    action = torch.randint(0, N, size=(B, ))
    value = torch.randn(B).requires_grad_(True)
    adv = torch.rand(B)
    return_ = torch.randn(B) * 2
    data = a2c_data(logit, action, value, adv, return_, weight)
    loss = a2c_error(data)

    # Every loss component must be reduced to a 0-dim tensor.
    assert all(term.shape == () for term in loss)
    # Computing the loss alone must not populate gradients.
    assert logit.grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)


@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_a2c_continuous(weight):
    """Check ``a2c_error_continuous`` on a continuous-action batch.

    The policy output is a dict with ``"mu"`` and ``"sigma"`` tensors.
    Verifies that every returned loss term has scalar shape, that no
    gradients exist before ``backward()``, and that ``mu``, ``sigma`` and
    ``value`` all receive tensor gradients after back-propagating the
    summed loss.

    Args:
        weight: ``None`` or a tensor of length ``B`` passed through to
            ``a2c_data``.
    """
    B, N = 4, 32
    logit = {
        "mu": torch.randn(B, N).requires_grad_(True),
        # exp() keeps sigma strictly positive.
        "sigma": torch.exp(torch.randn(B, N)).requires_grad_(True),
    }
    action = torch.randn(B, N).requires_grad_(True)
    value = torch.randn(B).requires_grad_(True)
    adv = torch.rand(B)
    return_ = torch.randn(B) * 2
    data = a2c_data(logit, action, value, adv, return_, weight)
    loss = a2c_error_continuous(data)

    # Every loss component must be reduced to a 0-dim tensor.
    assert all(term.shape == () for term in loss)
    # Computing the loss alone must not populate gradients.
    assert logit["mu"].grad is None
    assert logit["sigma"].grad is None
    assert value.grad is None

    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit["mu"].grad, torch.Tensor)
    assert isinstance(logit['sigma'].grad, torch.Tensor)
    assert isinstance(value.grad, torch.Tensor)