|
import pytest |
|
import copy |
|
from collections import deque |
|
import numpy as np |
|
import torch |
|
from ding.rl_utils import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample |
|
|
|
|
|
@pytest.mark.unittest
class TestAdder:
    """Unit tests for ``get_gae``, ``get_gae_with_default_last_value``,
    ``get_nstep_return_data`` and ``get_train_sample`` from ``ding.rl_utils``.

    Inputs are random transitions built by ``get_transition`` and
    ``get_transition_multi_agent``.
    """

    def get_transition(self):
        """Return one random single-agent transition dict.

        Keys: ``value`` (tensor, shape (1,)), ``reward`` (tensor, shape (1,)),
        ``action`` (tensor, shape (3,)), ``other`` (int ndarray, shape (4,)),
        ``obs`` (tensor, shape (3,)) and ``done`` (always ``False``).
        """
        return {
            'value': torch.randn(1),
            'reward': torch.rand(1),
            'action': torch.rand(3),
            'other': np.random.randint(0, 10, size=(4, )),
            'obs': torch.randn(3),
            'done': False
        }

    def get_transition_multi_agent(self):
        """Return one random multi-agent transition dict.

        Same keys as ``get_transition``, except that ``value`` has shape
        (1, 8) and ``reward`` has shape (1, 1).
        """
        return {
            'value': torch.randn(1, 8),
            'reward': torch.rand(1, 1),
            'action': torch.rand(3),
            'other': np.random.randint(0, 10, size=(4, )),
            'obs': torch.randn(3),
            'done': False
        }

    def _check_gae(self, transitions, last_value, adv_shape):
        """Run the assertions shared by the single- and multi-agent GAE tests.

        Args:
            transitions: deque of transition dicts handed to ``get_gae``.
            last_value: tensor passed as ``last_value`` to ``get_gae``.
            adv_shape: shape every produced ``'adv'`` tensor must have.
        """
        output = get_gae(transitions, last_value, gamma=0.99, gae_lambda=0.97, cuda=False)
        for i, o in enumerate(output):
            assert 'adv' in o
            for k, v in o.items():
                if k == 'adv':
                    assert isinstance(v, torch.Tensor)
                    assert v.shape == adv_shape
                elif k == 'done':
                    # 'done' is a plain bool, so it has no ``.all()``.
                    assert v == transitions[i][k]
                else:
                    assert (v == transitions[i][k]).all()

        # With done=True the advantages are expected to differ element-wise
        # from the ones computed with the random ``last_value``.
        output1 = get_gae_with_default_last_value(
            copy.deepcopy(transitions), True, gamma=0.99, gae_lambda=0.97, cuda=False
        )
        for i, o in enumerate(output):
            assert o['adv'].ne(output1[i]['adv']).all()

        # With done=False and ``last_value`` appended as the 'value' of an
        # extra trailing element, the advantages are expected to match.
        data = copy.deepcopy(transitions)
        data.append({'value': last_value})
        output2 = get_gae_with_default_last_value(data, False, gamma=0.99, gae_lambda=0.97, cuda=False)
        for i, o in enumerate(output):
            assert o['adv'].eq(output2[i]['adv']).all()

    def test_get_gae(self):
        """GAE on single-agent transitions: 'adv' has shape (1,)."""
        transitions = deque([self.get_transition() for _ in range(10)])
        last_value = torch.randn(1)
        self._check_gae(transitions, last_value, adv_shape=(1, ))

    def test_get_gae_multi_agent(self):
        """GAE on multi-agent transitions: 'adv' has shape (1, 8)."""
        transitions = deque([self.get_transition_multi_agent() for _ in range(10)])
        last_value = torch.randn(1, 8)
        self._check_gae(transitions, last_value, adv_shape=(1, 8))

    def test_get_nstep_return_data(self):
        """n-step data keeps the trajectory length and gives 'reward' shape (nstep,)."""
        nstep = 3
        traj_len = 10
        data = deque([self.get_transition() for _ in range(traj_len)])
        output_data = get_nstep_return_data(data, nstep=nstep)
        assert len(output_data) == traj_len
        for i, o in enumerate(output_data):
            assert o['reward'].shape == (nstep, )
            if i >= traj_len - nstep + 1:
                # Near the end of the trajectory the trailing reward slots
                # are expected to sum to zero.
                assert o['done'] is data[-1]['done']
                assert o['reward'][-(i - traj_len + nstep):].sum() == 0

        data = deque([self.get_transition() for _ in range(12)])
        output_data = get_nstep_return_data(data, nstep=nstep)
        assert len(output_data) == 12

    def test_get_train_sample(self):
        """Check sample count and per-key length for each ``last_fn_type``."""
        traj_len = 10
        unroll_len = 4
        remainder = traj_len % unroll_len
        data = [self.get_transition() for _ in range(traj_len)]

        output = get_train_sample(data, unroll_len=1, last_fn_type='drop')
        assert len(output) == traj_len

        output = get_train_sample(data, unroll_len=unroll_len, last_fn_type='drop')
        assert len(output) == 2
        for o in output:
            for v in o.values():
                assert len(v) == unroll_len

        output = get_train_sample(data, unroll_len=unroll_len, last_fn_type='null_padding')
        assert len(output) == 3
        for o in output:
            for v in o.values():
                assert len(v) == unroll_len
        assert output[-1]['done'] == [False, False, True, True]
        last_obs = output[-1]['obs']
        for i in range(1, remainder + 1):
            # ``is not`` keeps both operands alive during the comparison,
            # unlike comparing two separate ``id()`` results.
            assert last_obs[-i] is not last_obs[0]

        output = get_train_sample(data, unroll_len=unroll_len, last_fn_type='last')
        assert len(output) == 3
        for o in output:
            for v in o.values():
                assert len(v) == unroll_len
        miss_num = unroll_len - remainder
        for i in range(remainder):
            assert output[-1]['obs'][i] is not output[-2]['obs'][miss_num + i]

        output = get_train_sample(data, unroll_len=11, last_fn_type='last')
        assert len(output) == 1
        assert len(output[0]['obs']) == 11
        assert output[-1]['done'][-1] is True
        assert output[-1]['done'][0] is False
        assert output[-1]['obs'][-1] is not output[-1]['obs'][0]
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc entry point: run the multi-agent GAE test when this file is
    # executed directly. Guarded so that merely importing the module (for
    # example during pytest collection) does not run a test as a side effect.
    test = TestAdder()
    test.test_get_gae_multi_agent()
|
|