import pytest
import torch
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\
    td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\
    dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\
    v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error,\
    fqf_nstep_td_data, fqf_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error, bdq_nstep_td_error,\
    m_q_1step_td_data, m_q_1step_td_error
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale


@pytest.mark.unittest
def test_q_nstep_td():
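    # q_nstep_td_error: loss and per-sample TD error shapes, gradient flow,
    # and the cum_reward / value_gamma variants for nstep 1..9.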
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma)
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_bdq_nstep_td():
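    # bdq_nstep_td_error: same checks for the branching (BDQ) action space, where each of
    # branch_num branches has action_per_branch discrete actions.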
    batch_size = 8
    branch_num = 6
    action_per_branch = 3
    next_q = torch.randn(batch_size, branch_num, action_per_branch)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    next_action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, branch_num, action_per_branch).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(
            data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_nstep_td_ngu():
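    # q_nstep_td_error with a per-sample list of gamma values (NGU-style discounting).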
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    gamma = [torch.tensor(0.95) for i in range(batch_size)]

    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, gamma, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_dist_1step_td():
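    # dist_1step_td_error: 1-step categorical (C51-style) distributional TD loss is a scalar and backpropagates.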
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_1step_compatible():
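    # q_nstep_td_error with nstep=1 should reproduce q_1step_td_error on the same data.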
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    reward = torch.rand(batch_size)
    nstep_data = q_nstep_td_data(q, next_q, action, next_action, reward.unsqueeze(0), done, None)
    onestep_data = q_1step_td_data(q, next_q, action, next_action, reward, done, None)
    nstep_loss, _ = q_nstep_td_error(nstep_data, 0.99, nstep=1)
    onestep_loss = q_1step_td_error(onestep_data, 0.99)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()


@pytest.mark.unittest
def test_dist_nstep_td():
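    # dist_nstep_td_error: n-step distributional TD loss, with and without sample weight / value_gamma.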
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    weight = torch.tensor([0.9])
    value_gamma = torch.tensor(0.9)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)


@pytest.mark.unittest
def test_dist_nstep_multi_agent_td():
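    # Multi-agent distributional n-step TD: the joint loss should equal the average of per-agent losses.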
    batch_size = 4
    action_dim = 3
    agent_num = 2
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    next_action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    weight = 0.9
    value_gamma = 0.9
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    agent_total_loss = 0
    for i in range(agent_num):
        data = dist_nstep_td_data(
            dist[:, i], next_n_dist[:, i], action[:, i], next_action[:, i], reward, done, weight
        )
        agent_loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
        agent_total_loss = agent_total_loss + agent_loss
    agent_average_loss = agent_total_loss / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5


@pytest.mark.unittest
def test_q_nstep_td_with_rescale():
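    # q_nstep_td_error_with_rescale: n-step TD loss with value rescaling, checked for nstep 1..9.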
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)


@pytest.mark.unittest
def test_q_nstep_td_with_rescale_ngu():
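    # Value-rescaled n-step TD with a per-sample list of gamma values (NGU-style discounting).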
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    gamma = [torch.tensor(0.95) for i in range(batch_size)]
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, gamma, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)


@pytest.mark.unittest
def test_qrdqn_nstep_td():
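    # qrdqn_nstep_td_error: quantile-regression DQN n-step TD loss, with and without value_gamma.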
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, action_dim, tau)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim, tau).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_dist_1step_compatible():
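    # dist_nstep_td_error with nstep=1 should reproduce dist_1step_td_error on the same data.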
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    onestep_data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    nstep_data = dist_nstep_td_data(dist, next_dist, action, next_action, reward.unsqueeze(0), done, None)
    onestep_loss = dist_1step_td_error(onestep_data, 0.95, v_min, v_max, n_atom)
    nstep_loss, _ = dist_nstep_td_error(nstep_data, 0.95, v_min, v_max, n_atom, nstep=1)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()


@pytest.mark.unittest
def test_dist_1step_multi_agent_td():
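    # Multi-agent 1-step distributional TD: the joint loss should equal the average of per-agent losses.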
    batch_size = 4
    action_dim = 3
    agent_num = 2
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    next_action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    reward = torch.randn(batch_size)
    data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    agent_total_loss = 0
    for i in range(agent_num):
        data = dist_1step_td_data(
            dist[:, i], next_dist[:, i], action[:, i], next_action[:, i], reward, done, None
        )
        agent_loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
        agent_total_loss = agent_total_loss + agent_loss
    agent_average_loss = agent_total_loss / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5


@pytest.mark.unittest
def test_td_lambda():
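    # td_lambda_error: TD(lambda) loss over a (T + 1, B) value sequence is a scalar and backpropagates.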
    T, B = 8, 4
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    loss = td_lambda_error(td_lambda_data(value, reward, None))
    assert loss.shape == ()
    assert value.grad is None
    loss.backward()
    assert isinstance(value.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_1step_td():
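    # v_1step_td_error: 1-step state-value TD loss, with and without a done mask.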
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_1step_multi_agent_td():
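    # v_1step_td_error with a multi-agent value of shape (batch_size, agent_num).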
    batch_size = 5
    agent_num = 2
    v = torch.randn(batch_size, agent_num).requires_grad_(True)
    next_v = torch.randn(batch_size, agent_num)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_nstep_td():
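    # v_nstep_td_error: n-step state-value TD loss, with and without value_gamma.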
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(5, batch_size)
    done = torch.zeros(batch_size)
    data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_nstep_td_data(v, next_v, reward, done, None, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_dqfd_nstep_td():
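    # dqfd_nstep_td_error: DQfD loss combining n-step TD, 1-step TD and the supervised margin loss on expert data.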
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    done_1 = torch.randn(batch_size)
    next_q_one_step = torch.randn(batch_size, action_dim)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
    is_expert = torch.ones((batch_size))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = dqfd_nstep_td_data(
            q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
        )
        loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
            data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
        )
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)


@pytest.mark.unittest
def test_q_nstep_sql_td():
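    # q_nstep_sql_td_error: soft Q-learning n-step TD loss at different temperatures, plus cum_reward / value_gamma.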
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 0.5, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
            data, 0.95, 0.5, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_iqn_nstep_td():
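    # iqn_nstep_td_error: IQN n-step TD loss with sampled replay quantiles, with and without value_gamma.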
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(tau, batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(tau, batch_size, action_dim).requires_grad_(True)
        replay_quantile = torch.randn([tau, batch_size, 1])
        reward = torch.rand(nstep, batch_size)
        data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_fqf_nstep_td():
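    # fqf_nstep_td_error: FQF n-step TD loss with learned quantile fractions, with and without value_gamma.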
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, tau, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, tau, action_dim).requires_grad_(True)
        quantiles_hats = torch.randn([batch_size, tau])
        reward = torch.rand(nstep, batch_size)
        data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_shape_fn_qntd():
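    # shape_fn_qntd should yield (T, B, N) = (nstep, batch_size, action_dim) for both positional and keyword calls.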
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        tmp = shape_fn_qntd([data, 0.95, 1], {})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]
        tmp = shape_fn_qntd([], {'gamma': 0.95, 'nstep': 1, 'data': data})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_shape_fn_dntd():
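    # shape_fn_dntd should yield (T, B, N, n_atom) for both positional and keyword calls.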
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    tmp = shape_fn_dntd([data, 0.9, v_min, v_max, n_atom, nstep], {})
    assert tmp[0] == reward.shape[0]
    assert tmp[1] == dist.shape[0]
    assert tmp[2] == dist.shape[1]
    assert tmp[3] == n_atom
    tmp = shape_fn_dntd([], {'data': data, 'gamma': 0.9, 'v_min': v_min, 'v_max': v_max, 'n_atom': n_atom, 'nstep': 5})
    assert tmp[0] == reward.shape[0]
    assert tmp[1] == dist.shape[0]
    assert tmp[2] == dist.shape[1]
    assert tmp[3] == n_atom


@pytest.mark.unittest
def test_shape_fn_qntd_rescale():
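    # shape_fn_qntd_rescale should yield (T, B, N) for both positional and keyword calls.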
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        tmp = shape_fn_qntd_rescale([data, 0.95, 1], {})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]
        tmp = shape_fn_qntd_rescale([], {'gamma': 0.95, 'nstep': 1, 'data': data})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_fn_td_lambda():
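    # shape_fn_td_lambda shape extraction for keyword and positional call styles.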
    T, B = 8, 4
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    data = td_lambda_data(value, reward, None)
    tmp = shape_fn_td_lambda([], {'data': data})
    assert tmp == reward.shape[0]
    tmp = shape_fn_td_lambda([data], {})
    assert tmp == reward.shape


@pytest.mark.unittest
def test_fn_m_q_1step_td_error():
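    # m_q_1step_td_error (Munchausen Q-learning): loss, per-sample TD error, action gap and clip fraction checks.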
    batch_size = 128
    action_dim = 9
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    target_q_current = torch.randn(batch_size, action_dim).requires_grad_(False)
    target_q_next = torch.randn(batch_size, action_dim).requires_grad_(False)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = m_q_1step_td_data(q, target_q_current, target_q_next, action, reward, done, None)
    loss, td_error_per_sample, action_gap, clip_frac = m_q_1step_td_error(data, 0.99, 0.03, 0.6)

    assert loss.shape == ()
    assert q.grad is None
    loss.backward()
    assert isinstance(q.grad, torch.Tensor)
    assert clip_frac.mean().item() <= 1
    assert action_gap.item() > 0
    assert td_error_per_sample.shape == (batch_size, )