import copy
from copy import deepcopy
from collections import OrderedDict

import pytest
import torch
import torch.nn as nn
from ditk import logging

from ding.torch_utils import get_lstm
from ding.torch_utils.network.gtrxl import GTrXL
from ding.model import model_wrap, register_wrapper, IModelWrapper
from ding.model.wrapper.model_wrappers import BaseModelWrapper
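
# These tests exercise DI-engine's model wrappers. The recurring pattern, shown
# throughout this file, is ``model_wrap(model, wrapper_name=...)``: the wrapper
# adds sampling / target-network / memory behaviour around ``forward`` while
# still exposing the wrapped module's attributes.
# The small modules below are stand-in networks used as the wrapped models.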
class TempMLP(torch.nn.Module):

    def __init__(self):
        super(TempMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        return x
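
# Discrete-policy actor: returns {'logit', 'action', 'tmp'} and echoes an
# optional 'action_mask' taken from the inputs.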
class ActorMLP(torch.nn.Module):

    def __init__(self):
        super(ActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.out(x)
        ret = {'logit': x, 'tmp': tmp, 'action': x + torch.rand_like(x)}
        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
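
# Hybrid actor: a discrete head ('logit') and a continuous head ('action_args')
# branching from a shared trunk.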
class HybridActorMLP(torch.nn.Module):

    def __init__(self):
        super(HybridActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)

        self.fc2_cont = nn.Linear(4, 6)
        self.act_cont = nn.ReLU()

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)

        x = self.fc2(x_)
        x = self.act(x)
        x_disc = self.out(x)

        x = self.fc2_cont(x_)
        x_cont = self.act_cont(x)

        ret = {'logit': x_disc, 'action_args': x_cont, 'tmp': tmp}

        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
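
# Hybrid actor with a reparameterized continuous head: 'logit' carries the
# discrete 'action_type' distribution plus Gaussian 'action_args' ({'mu', 'sigma'}).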
class HybridReparamActorMLP(torch.nn.Module):

    def __init__(self):
        super(HybridReparamActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()
        self.out = nn.Softmax(dim=-1)

        self.fc2_cont_mu = nn.Linear(4, 6)
        self.act_cont_mu = nn.ReLU()

        self.fc2_cont_sigma = nn.Linear(4, 6)
        self.act_cont_sigma = nn.ReLU()

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)

        x = self.fc2(x_)
        x = self.act(x)
        x_disc = self.out(x)

        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act_cont_mu(x)

        x = self.fc2_cont_sigma(x_)
        x_cont_sigma = self.act_cont_sigma(x) + 1e-8

        ret = {'logit': {'action_type': x_disc, 'action_args': {'mu': x_cont_mu, 'sigma': x_cont_sigma}}, 'tmp': tmp}

        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
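
# Continuous actor whose 'logit' is a Gaussian parameterization {'mu', 'sigma'}.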
class ReparamActorMLP(torch.nn.Module):

    def __init__(self):
        super(ReparamActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(4, 6)
        self.act = nn.ReLU()

        self.fc2_cont_mu = nn.Linear(4, 6)
        self.fc2_cont_sigma = nn.Linear(4, 6)

    def forward(self, inputs, tmp=0):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)

        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act(x)

        x = self.fc2_cont_sigma(x_)
        x_cont_sigma = self.act(x) + 1e-8

        ret = {'logit': {'mu': x_cont_mu, 'sigma': x_cont_sigma}, 'tmp': tmp}

        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
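
# Deterministic continuous actor: only a 'mu' head, no sigma.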
class DeterministicActorMLP(torch.nn.Module):

    def __init__(self):
        super(DeterministicActorMLP, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.bn1 = nn.BatchNorm1d(4)
        self.act = nn.ReLU()

        self.fc2_cont_mu = nn.Linear(4, 6)
        self.act_cont_mu = nn.ReLU()

    def forward(self, inputs):
        x = self.fc1(inputs['obs'])
        x = self.bn1(x)
        x_ = self.act(x)

        x = self.fc2_cont_mu(x_)
        x_cont_mu = self.act_cont_mu(x)

        ret = {
            'logit': {
                'mu': x_cont_mu,
            }
        }

        if 'mask' in inputs:
            ret['action_mask'] = inputs['mask']
        return ret
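
# LSTM model used by the hidden_state wrapper test; forward consumes
# {'f': features, 'prev_state': hidden states} and returns the next states.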
class TempLSTM(torch.nn.Module):

    def __init__(self):
        super(TempLSTM, self).__init__()
        self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)

    def forward(self, data):
        output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
        return {'output': output, 'next_state': next_state}


@pytest.fixture(scope='function')
def setup_model():
    return torch.nn.Linear(3, 6)


@pytest.mark.unittest
class TestModelWrappers:
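
    # hidden_state wrapper: keeps one LSTM state per environment (state_num in
    # total), selects them by data_id on each forward, and returns the previous
    # states as 'prev_state' when save_prev_state=True.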
    def test_hidden_state_wrapper(self):

        model = TempLSTM()
        state_num = 4
        model = model_wrap(model, wrapper_name='hidden_state', state_num=state_num, save_prev_state=True)
        model.reset()
        data = {'f': torch.randn(2, 4, 36)}
        output = model.forward(data)
        assert output['output'].shape == (2, state_num, 32)
        assert len(output['prev_state']) == 4
        assert output['prev_state'][0]['h'].shape == (2, 1, 32)
        for item in model._state.values():
            assert isinstance(item, dict) and len(item) == 2
            assert all(t.shape == (2, 1, 32) for t in item.values())

        data = {'f': torch.randn(2, 3, 36)}
        data_id = [0, 1, 3]
        output = model.forward(data, data_id=data_id)
        assert output['output'].shape == (2, 3, 32)
        assert all([len(s) == 2 for s in output['prev_state']])
        for item in model._state.values():
            assert isinstance(item, dict) and len(item) == 2
            assert all(t.shape == (2, 1, 32) for t in item.values())

        data = {'f': torch.randn(2, 2, 36)}
        data_id = [0, 1]
        output = model.forward(data, data_id=data_id)
        assert output['output'].shape == (2, 2, 32)

        assert all([isinstance(s, dict) and len(s) == 2 for s in model._state.values()])
        model.reset()
        assert all([isinstance(s, type(None)) for s in model._state.values()])
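
    # target wrapper: update_type='assign' copies the source parameters every
    # `freq` update() calls (or immediately with direct=True); 'momentum' blends
    # them as theta * src + (1 - theta) * target.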
    def test_target_network_wrapper(self):

        model = TempMLP()
        target_model = deepcopy(model)
        target_model2 = deepcopy(model)
        target_model = model_wrap(target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': 2})
        model = model_wrap(model, wrapper_name='base')
        register_wrapper('abstract', IModelWrapper)
        assert all([hasattr(target_model, n) for n in ['reset', 'forward', 'update']])
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict(), direct=True)
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        model.reset()
        target_model.reset()

        inputs = torch.randn(2, 3)
        model.train()
        target_model.train()
        output = model.forward(inputs)
        with torch.no_grad():
            output_target = target_model.forward(inputs)
        assert output.eq(output_target).sum() == 2 * 6
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict())
        assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
        target_model.update(model.state_dict())
        assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12

        assert target_model._update_count != 0
        target_model.reset()
        assert target_model._update_count != 0
        target_model.reset(target_update_count=0)
        assert target_model._update_count == 0

        target_model2 = model_wrap(
            target_model2, wrapper_name='target', update_type='momentum', update_kwargs={'theta': 0.01}
        )
        target_model2.update(model.state_dict(), direct=True)
        assert model.fc1.weight.eq(target_model2.fc1.weight).sum() == 12
        model.fc1.weight.data = torch.randn_like(model.fc1.weight)
        old_state_dict = target_model2.state_dict()
        target_model2.update(model.state_dict())
        assert target_model2.fc1.weight.data.eq(
            old_state_dict['fc1.weight'] * (1 - 0.01) + model.fc1.weight.data * 0.01
        ).all()
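
    # eps_greedy_sample wrapper: epsilon-greedy action selection on top of the
    # actor's logits; extra kwargs (e.g. tmp) and the optional 'mask' are passed
    # through to the wrapped model.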
    def test_eps_greedy_wrapper(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
            assert output['tmp'] == 1

    def test_multinomial_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='multinomial_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action'].shape == (4, )
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        assert output['action'].shape == (4, )

    def test_eps_greedy_multinomial_wrapper(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold, alpha=0.2)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1, alpha=0.2)
            assert isinstance(output, dict)
            assert output['tmp'] == 1

    def test_hybrid_eps_greedy_wrapper(self):
        model = HybridActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_eps_greedy_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)

        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)

    def test_hybrid_eps_greedy_multinomial_wrapper(self):
        model = HybridActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
        model.eval()
        eps_threshold = 0.5
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert isinstance(output['logit'], torch.Tensor) and output['logit'].shape == (4, 6)
        assert isinstance(output['action']['action_type'],
                          torch.Tensor) and output['action']['action_type'].shape == (4, )
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)

    def test_hybrid_reparam_multinomial_wrapper(self):
        model = HybridReparamActorMLP()
        model = model_wrap(model, wrapper_name='hybrid_reparam_multinomial_sample')
        model.eval()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data)
        assert isinstance(output['logit'], dict) and output['logit']['action_type'].shape == (4, 6)
        assert isinstance(output['logit']['action_args'], dict) and output['logit']['action_args']['mu'].shape == (
            4, 6
        ) and output['logit']['action_args']['sigma'].shape == (4, 6)
        assert isinstance(output['action']['action_type'],
                          torch.Tensor) and output['action']['action_type'].shape == (4, )
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, tmp=1)
            assert isinstance(output, dict)
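
    # argmax_sample wrappers: take the greedy action; when an action mask is
    # present, masked-out logits are suppressed with a large negative offset
    # before the argmax.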
    def test_argmax_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        logit = output['logit']
        assert output['action'].eq(logit.argmax(dim=-1)).all()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        logit = output['logit'].sub(1e8 * (1 - data['mask']))
        assert output['action'].eq(logit.argmax(dim=-1)).all()

    def test_hybrid_argmax_sample_wrapper(self):
        model = model_wrap(HybridActorMLP(), wrapper_name='hybrid_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        logit = output['logit']
        assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        output = model.forward(data)
        logit = output['logit'].sub(1e8 * (1 - data['mask']))
        assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
        assert output['action']['action_args'].shape == (4, 6)

    def test_hybrid_deterministic_argmax_sample_wrapper(self):
        model = model_wrap(HybridReparamActorMLP(), wrapper_name='hybrid_deterministic_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action']['action_type'].eq(output['logit']['action_type'].argmax(dim=-1)).all()
        assert isinstance(output['action']['action_args'],
                          torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        assert output['action']['action_args'].eq(output['logit']['action_args']['mu']).all()

    def test_deterministic_sample_wrapper(self):
        model = model_wrap(DeterministicActorMLP(), wrapper_name='deterministic_sample')
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        assert output['action'].eq(output['logit']['mu']).all()
        assert isinstance(output['action'], torch.Tensor) and output['action'].shape == (4, 6)

    def test_reparam_wrapper(self):
        model = ReparamActorMLP()
        model = model_wrap(model, wrapper_name='reparam_sample')
        model.eval()
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data)
        assert isinstance(output['logit'],
                          dict) and output['logit']['mu'].shape == (4, 6) and output['logit']['sigma'].shape == (4, 6)
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, tmp=1)
            assert isinstance(output, dict)
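
    # eps_greedy_sample also accepts a per-sample epsilon (here a dict with one
    # value per batch index) instead of a single scalar.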
    def test_eps_greedy_wrapper_with_list_eps(self):
        model = ActorMLP()
        model = model_wrap(model, wrapper_name='eps_greedy_sample')
        model.eval()
        eps_threshold = {i: 0.5 for i in range(4)}
        data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
        with torch.no_grad():
            output = model.forward(data, eps=eps_threshold)
        assert output['tmp'] == 0
        for i in range(10):
            if i == 5:
                data.pop('mask')
            with torch.no_grad():
                output = model.forward(data, eps=eps_threshold, tmp=1)
            assert isinstance(output, dict)
            assert output['tmp'] == 1
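
    # action_noise wrapper: adds noise (here Gaussian, clipped to noise_range) to
    # the action and clamps the result into action_range.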
    def test_action_noise_wrapper(self):
        model = model_wrap(
            ActorMLP(),
            wrapper_name='action_noise',
            noise_type='gauss',
            noise_range={
                'min': -0.1,
                'max': 0.1
            },
            action_range={
                'min': -0.05,
                'max': 0.05
            }
        )
        data = {'obs': torch.randn(4, 3)}
        output = model.forward(data)
        action = output['action']
        assert action.shape == (4, 6)
        assert action.eq(action.clamp(-0.05, 0.05)).all()
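
    # transformer_input wrapper: buffers the last `seq_len` observations per
    # batch index, feeds the whole sequence to GTrXL, and keeps unseen positions
    # zero-padded; reset() clears the buffer (optionally per data_id).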
    def test_transformer_input_wrapper(self):
        seq_len, bs, obs_shape = 8, 8, 32
        emb_dim = 64
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim)
        model = model_wrap(model, wrapper_name='transformer_input', seq_len=seq_len)
        obs = []
        for i in range(seq_len + 1):
            obs.append(torch.randn((bs, obs_shape)))
        out = model.forward(obs[0], only_last_logit=False)
        assert out['logit'].shape == (seq_len, bs, emb_dim)
        assert out['input_seq'].shape == (seq_len, bs, obs_shape)
        assert sum(out['input_seq'][1:].flatten()) == 0
        for i in range(1, seq_len - 1):
            out = model.forward(obs[i])
        assert out['logit'].shape == (bs, emb_dim)
        assert out['input_seq'].shape == (seq_len, bs, obs_shape)
        assert sum(out['input_seq'][seq_len - 1:].flatten()) == 0
        assert sum(out['input_seq'][:seq_len - 1].flatten()) != 0
        out = model.forward(obs[seq_len - 1])
        prev_memory = torch.clone(out['input_seq'])
        out = model.forward(obs[seq_len])
        assert torch.all(torch.eq(out['input_seq'][seq_len - 2], prev_memory[seq_len - 1]))

        model.reset(data_id=[0, 5])
        assert sum(model.obs_memory[:, 0].flatten()) == 0 and sum(model.obs_memory[:, 5].flatten()) == 0
        assert sum(model.obs_memory[:, 1].flatten()) != 0
        assert model.memory_idx[0] == 0 and model.memory_idx[5] == 0 and model.memory_idx[1] == seq_len

        model.reset()
        assert model.obs_memory is None
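
    # transformer_segment wrapper: forwards a whole sequence of length seq_len at
    # once; the test also exercises the wrapper's info() helper.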
    def test_transformer_segment_wrapper(self):
        seq_len, bs, obs_shape = 12, 8, 32
        layer_num, memory_len, emb_dim = 3, 4, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len)
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs1)
        info = model.info('info')
        info = model.info('x')
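
    # transformer_memory wrapper: maintains the GTrXL recurrent memory between
    # forward calls in a FIFO fashion; reset() zeroes it (optionally per data_id).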
    def test_transformer_memory_wrapper(self):
        seq_len, bs, obs_shape = 12, 8, 32
        layer_num, memory_len, emb_dim = 3, 4, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        model1.show_memory_occupancy()
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model1.forward(inputs1)
        new_memory1 = model1.memory
        inputs2 = torch.randn((seq_len, bs, obs_shape))
        out = model2.forward(inputs2)
        new_memory2 = model2.memory
        assert not torch.all(torch.eq(new_memory1, new_memory2))
        model1.reset(data_id=[0, 5])
        assert sum(model1.memory[:, :, 0].flatten()) == 0 and sum(model1.memory[:, :, 5].flatten()) == 0
        assert sum(model1.memory[:, :, 1].flatten()) != 0
        model1.reset()
        assert sum(model1.memory.flatten()) == 0

        seq_len, bs, obs_shape = 8, 8, 32
        layer_num, memory_len, emb_dim = 3, 20, 4
        model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
        model = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
        inputs1 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs1)
        new_memory1 = model.memory
        inputs2 = torch.randn((seq_len, bs, obs_shape))
        out = model.forward(inputs2)
        new_memory2 = model.memory
        print(new_memory1.shape, inputs1.shape)
        assert sum(new_memory1[:, -8:].flatten()) != 0
        assert sum(new_memory1[:, :-8].flatten()) == 0
        assert sum(new_memory2[:, -16:].flatten()) != 0
        assert sum(new_memory2[:, :-16].flatten()) == 0
        assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
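
    # combination_{argmax,multinomial}_sample wrappers: select `shot_number`
    # actions per sample from the logits.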
    def test_combination_argmax_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
        data = {'obs': torch.randn(4, 3)}
        shot_number = 2
        output = model.forward(shot_number=shot_number, inputs=data)
        assert output['action'].shape == (4, shot_number)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()

    def test_combination_multinomial_sample_wrapper(self):
        model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
        data = {'obs': torch.randn(4, 3)}
        shot_number = 2
        output = model.forward(shot_number=shot_number, inputs=data)
        assert output['action'].shape == (4, shot_number)
        assert (output['action'] >= 0).all() and (output['action'] < 64).all()