import pytest
import torch

from ding.rl_utils import compute_q_retraces


@pytest.mark.unittest
def test_compute_q_retraces():
    """Call ``compute_q_retraces`` on random inputs and check the output shape.

    Only the shape of the returned tensor is verified here; the numerical
    values are not compared against a reference.
    """
    T, B, N = 64, 32, 6

    # Random inputs; the creation order matches the argument shapes used below.
    q_values = torch.randn((T + 1, B, N))
    v_pred = torch.randn((T + 1, B, 1))
    rewards = torch.randn((T, B))

    # Uniform samples from [0, 1) are rescaled so every ratio lies in [0.8, 1.2).
    ratio = 0.8 + 0.4 * torch.rand((T, B, N))
    assert ratio.min() >= 0.8 and ratio.max() <= 1.2

    weights = torch.rand((T, B))
    # Integer action indices in the range [0, N).
    actions = torch.randint(N, (T, B))

    # Gradients are not needed for a pure shape check.
    with torch.no_grad():
        result = compute_q_retraces(
            q_values,
            v_pred,
            rewards,
            actions,
            weights,
            ratio,
            gamma=0.99,
        )

    expected_shape = (T + 1, B, 1)
    assert result.shape == expected_shape