# DI-engine: ding/model/wrapper/test_model_wrappers.py
import copy
from copy import deepcopy
from collections import OrderedDict
import pytest
import torch
import torch.nn as nn
from ditk import logging
from ding.torch_utils import get_lstm
from ding.torch_utils.network.gtrxl import GTrXL
from ding.model import model_wrap, register_wrapper, IModelWrapper
from ding.model.wrapper.model_wrappers import BaseModelWrapper
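# Dummy models used as wrapping targets in the tests below.
# TempMLP: plain MLP mapping a 3-dim tensor input to a 6-dim output.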
class TempMLP(torch.nn.Module):
def __init__(self):
super(TempMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x)
x = self.act(x)
x = self.fc2(x)
x = self.act(x)
return x
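# ActorMLP: discrete-action actor. Takes a dict with 'obs' (and optional 'mask');
# returns softmax 'logit', a noisy 'action', the pass-through 'tmp' and, if given,
# the 'action_mask'.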
class ActorMLP(torch.nn.Module):
def __init__(self):
super(ActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x = self.act(x)
x = self.fc2(x)
x = self.act(x)
x = self.out(x)
ret = {'logit': x, 'tmp': tmp, 'action': x + torch.rand_like(x)}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
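# HybridActorMLP: hybrid-action actor returning a discrete 'logit' plus continuous 'action_args'.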
class HybridActorMLP(torch.nn.Module):
def __init__(self):
super(HybridActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
self.fc2_cont = nn.Linear(4, 6)
self.act_cont = nn.ReLU()
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2(x_)
x = self.act(x)
x_disc = self.out(x)
x = self.fc2_cont(x_)
x_cont = self.act_cont(x)
ret = {'logit': x_disc, 'action_args': x_cont, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
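# HybridReparamActorMLP: hybrid actor whose continuous branch is reparameterized as (mu, sigma).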
class HybridReparamActorMLP(torch.nn.Module):
def __init__(self):
super(HybridReparamActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
self.fc2_cont_mu = nn.Linear(4, 6)
self.act_cont_mu = nn.ReLU()
self.fc2_cont_sigma = nn.Linear(4, 6)
self.act_cont_sigma = nn.ReLU()
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2(x_)
x = self.act(x)
x_disc = self.out(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act_cont_mu(x)
x = self.fc2_cont_sigma(x_)
x_cont_sigma = self.act_cont_sigma(x) + 1e-8
ret = {'logit': {'action_type': x_disc, 'action_args': {'mu': x_cont_mu, 'sigma': x_cont_sigma}}, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
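# ReparamActorMLP: continuous actor returning a reparameterized {'mu', 'sigma'} logit dict.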
class ReparamActorMLP(torch.nn.Module):
def __init__(self):
super(ReparamActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.fc2_cont_mu = nn.Linear(4, 6)
self.fc2_cont_sigma = nn.Linear(4, 6)
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act(x)
x = self.fc2_cont_sigma(x_)
x_cont_sigma = self.act(x) + 1e-8
ret = {'logit': {'mu': x_cont_mu, 'sigma': x_cont_sigma}, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
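# DeterministicActorMLP: continuous actor returning only a deterministic 'mu' logit.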
class DeterministicActorMLP(torch.nn.Module):
def __init__(self):
super(DeterministicActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.act = nn.ReLU()
self.fc2_cont_mu = nn.Linear(4, 6)
self.act_cont_mu = nn.ReLU()
def forward(self, inputs):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act_cont_mu(x)
ret = {
'logit': {
'mu': x_cont_mu,
}
}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
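# TempLSTM: thin wrapper around ding's LSTM; forwards features and the previous hidden state.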
class TempLSTM(torch.nn.Module):
def __init__(self):
super(TempLSTM, self).__init__()
self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)
def forward(self, data):
output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
return {'output': output, 'next_state': next_state}
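# A bare nn.Linear model provided as a per-test fixture.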
@pytest.fixture(scope='function')
def setup_model():
return torch.nn.Linear(3, 6)
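# The tests below exercise ding.model.model_wrap, which decorates an nn.Module with
# extra behaviour selected by wrapper_name (sampling strategies, target networks,
# recurrent hidden-state handling, transformer memory, ...), e.g.:
#     model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
#     output = model.forward({'obs': torch.randn(4, 3)})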
@pytest.mark.unittest
class TestModelWrappers:
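    # 'hidden_state' wrapper: maintains per-sample LSTM states (selected via data_id) across forward calls.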
def test_hidden_state_wrapper(self):
model = TempLSTM()
state_num = 4
model = model_wrap(model, wrapper_name='hidden_state', state_num=state_num, save_prev_state=True)
model.reset()
data = {'f': torch.randn(2, 4, 36)}
output = model.forward(data)
assert output['output'].shape == (2, state_num, 32)
assert len(output['prev_state']) == 4
assert output['prev_state'][0]['h'].shape == (2, 1, 32)
for item in model._state.values():
assert isinstance(item, dict) and len(item) == 2
assert all(t.shape == (2, 1, 32) for t in item.values())
data = {'f': torch.randn(2, 3, 36)}
data_id = [0, 1, 3]
output = model.forward(data, data_id=data_id)
assert output['output'].shape == (2, 3, 32)
assert all([len(s) == 2 for s in output['prev_state']])
for item in model._state.values():
assert isinstance(item, dict) and len(item) == 2
assert all(t.shape == (2, 1, 32) for t in item.values())
data = {'f': torch.randn(2, 2, 36)}
data_id = [0, 1]
output = model.forward(data, data_id=data_id)
assert output['output'].shape == (2, 2, 32)
assert all([isinstance(s, dict) and len(s) == 2 for s in model._state.values()])
model.reset()
assert all([isinstance(s, type(None)) for s in model._state.values()])
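    # 'target' wrapper: target network kept in sync with the source model via assign (freq) or momentum (theta) updates.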
def test_target_network_wrapper(self):
model = TempMLP()
target_model = deepcopy(model)
target_model2 = deepcopy(model)
target_model = model_wrap(target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': 2})
model = model_wrap(model, wrapper_name='base')
register_wrapper('abstract', IModelWrapper)
assert all([hasattr(target_model, n) for n in ['reset', 'forward', 'update']])
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict(), direct=True)
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
model.reset()
target_model.reset()
inputs = torch.randn(2, 3)
model.train()
target_model.train()
output = model.forward(inputs)
with torch.no_grad():
output_target = target_model.forward(inputs)
assert output.eq(output_target).sum() == 2 * 6
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict())
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict())
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
        # reset() should not clear _update_count unless target_update_count is passed explicitly
assert target_model._update_count != 0
target_model.reset()
assert target_model._update_count != 0
target_model.reset(target_update_count=0)
assert target_model._update_count == 0
target_model2 = model_wrap(
target_model2, wrapper_name='target', update_type='momentum', update_kwargs={'theta': 0.01}
)
target_model2.update(model.state_dict(), direct=True)
assert model.fc1.weight.eq(target_model2.fc1.weight).sum() == 12
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
old_state_dict = target_model2.state_dict()
target_model2.update(model.state_dict())
assert target_model2.fc1.weight.data.eq(
old_state_dict['fc1.weight'] * (1 - 0.01) + model.fc1.weight.data * 0.01
).all()
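    # 'eps_greedy_sample' wrapper: epsilon-greedy action selection, with and without an action mask.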
def test_eps_greedy_wrapper(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
assert output['tmp'] == 1
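    # 'multinomial_sample' wrapper: samples discrete actions from the logits.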
def test_multinomial_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='multinomial_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action'].shape == (4, )
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
assert output['action'].shape == (4, )
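    # 'eps_greedy_multinomial_sample' wrapper: epsilon-greedy exploration on top of multinomial sampling.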
def test_eps_greedy_multinomial_wrapper(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, alpha=0.2)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1, alpha=0.2)
assert isinstance(output, dict)
assert output['tmp'] == 1
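    # 'hybrid_eps_greedy_sample' wrapper: epsilon-greedy over the discrete part of a hybrid action space.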
def test_hybrid_eps_greedy_wrapper(self):
model = HybridActorMLP()
model = model_wrap(model, wrapper_name='hybrid_eps_greedy_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
# logit = output['logit']
# assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
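    # 'hybrid_eps_greedy_multinomial_sample' wrapper: hybrid action space with multinomial discrete sampling.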
def test_hybrid_eps_greedy_multinomial_wrapper(self):
model = HybridActorMLP()
model = model_wrap(model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert isinstance(output['logit'], torch.Tensor) and output['logit'].shape == (4, 6)
assert isinstance(output['action']['action_type'],
torch.Tensor) and output['action']['action_type'].shape == (4, )
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
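    # 'hybrid_reparam_multinomial_sample' wrapper: samples the discrete action type and the reparameterized continuous args.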
def test_hybrid_reparam_multinomial_wrapper(self):
model = HybridReparamActorMLP()
model = model_wrap(model, wrapper_name='hybrid_reparam_multinomial_sample')
model.eval()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data)
assert isinstance(output['logit'], dict) and output['logit']['action_type'].shape == (4, 6)
assert isinstance(output['logit']['action_args'], dict) and output['logit']['action_args']['mu'].shape == (
4, 6
) and output['logit']['action_args']['sigma'].shape == (4, 6)
assert isinstance(output['action']['action_type'],
torch.Tensor) and output['action']['action_type'].shape == (4, )
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, tmp=1)
assert isinstance(output, dict)
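    # 'argmax_sample' wrapper: greedy action selection, honouring the action mask when provided.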
def test_argmax_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
logit = output['logit']
assert output['action'].eq(logit.argmax(dim=-1)).all()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
logit = output['logit'].sub(1e8 * (1 - data['mask']))
assert output['action'].eq(logit.argmax(dim=-1)).all()
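    # 'hybrid_argmax_sample' wrapper: greedy discrete action plus pass-through continuous args.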
def test_hybrid_argmax_sample_wrapper(self):
model = model_wrap(HybridActorMLP(), wrapper_name='hybrid_argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
logit = output['logit']
assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
logit = output['logit'].sub(1e8 * (1 - data['mask']))
assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert output['action']['action_args'].shape == (4, 6)
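    # 'hybrid_deterministic_argmax_sample' wrapper: greedy discrete action; continuous args taken as the mu output.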
def test_hybrid_deterministic_argmax_sample_wrapper(self):
model = model_wrap(HybridReparamActorMLP(), wrapper_name='hybrid_deterministic_argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action']['action_type'].eq(output['logit']['action_type'].argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        assert output['action']['action_args'].eq(output['logit']['action_args']['mu']).all()
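    # 'deterministic_sample' wrapper: the action is the 'mu' entry of the logit dict.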
def test_deterministic_sample_wrapper(self):
model = model_wrap(DeterministicActorMLP(), wrapper_name='deterministic_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action'].eq(output['logit']['mu']).all()
assert isinstance(output['action'], torch.Tensor) and output['action'].shape == (4, 6)
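    # 'reparam_sample' wrapper: samples continuous actions from the (mu, sigma) outputs.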
def test_reparam_wrapper(self):
model = ReparamActorMLP()
model = model_wrap(model, wrapper_name='reparam_sample')
model.eval()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data)
assert isinstance(output['logit'],
dict) and output['logit']['mu'].shape == (4, 6) and output['logit']['sigma'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, tmp=1)
assert isinstance(output, dict)
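    # 'eps_greedy_sample' wrapper with a per-sample epsilon dict (as used by NGU).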
def test_eps_greedy_wrapper_with_list_eps(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_sample')
model.eval()
eps_threshold = {i: 0.5 for i in range(4)} # for NGU
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
assert output['tmp'] == 1
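    # 'action_noise' wrapper: adds Gaussian exploration noise and clamps actions to the configured range.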
def test_action_noise_wrapper(self):
model = model_wrap(
ActorMLP(),
wrapper_name='action_noise',
noise_type='gauss',
noise_range={
'min': -0.1,
'max': 0.1
},
action_range={
'min': -0.05,
'max': 0.05
}
)
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
action = output['action']
assert action.shape == (4, 6)
assert action.eq(action.clamp(-0.05, 0.05)).all()
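    # 'transformer_input' wrapper: buffers single-step observations into a fixed-length GTrXL input sequence.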
def test_transformer_input_wrapper(self):
seq_len, bs, obs_shape = 8, 8, 32
emb_dim = 64
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim)
model = model_wrap(model, wrapper_name='transformer_input', seq_len=seq_len)
obs = []
for i in range(seq_len + 1):
obs.append(torch.randn((bs, obs_shape)))
out = model.forward(obs[0], only_last_logit=False)
assert out['logit'].shape == (seq_len, bs, emb_dim)
assert out['input_seq'].shape == (seq_len, bs, obs_shape)
assert sum(out['input_seq'][1:].flatten()) == 0
for i in range(1, seq_len - 1):
out = model.forward(obs[i])
assert out['logit'].shape == (bs, emb_dim)
assert out['input_seq'].shape == (seq_len, bs, obs_shape)
assert sum(out['input_seq'][seq_len - 1:].flatten()) == 0
assert sum(out['input_seq'][:seq_len - 1].flatten()) != 0
out = model.forward(obs[seq_len - 1])
prev_memory = torch.clone(out['input_seq'])
out = model.forward(obs[seq_len])
assert torch.all(torch.eq(out['input_seq'][seq_len - 2], prev_memory[seq_len - 1]))
        # test selective reset of single batch slots in the memory
model.reset(data_id=[0, 5]) # reset memory batch in position 0 and 5
assert sum(model.obs_memory[:, 0].flatten()) == 0 and sum(model.obs_memory[:, 5].flatten()) == 0
assert sum(model.obs_memory[:, 1].flatten()) != 0
assert model.memory_idx[0] == 0 and model.memory_idx[5] == 0 and model.memory_idx[1] == seq_len
# test reset
model.reset()
assert model.obs_memory is None
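    # 'transformer_segment' wrapper: forwards a full sequence segment through GTrXL; also exercises info().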
def test_transformer_segment_wrapper(self):
seq_len, bs, obs_shape = 12, 8, 32
layer_num, memory_len, emb_dim = 3, 4, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len)
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs1)
info = model.info('info')
info = model.info('x')
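    # 'transformer_memory' wrapper: maintains GTrXL memory across forward calls, with per-batch reset.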
def test_transformer_memory_wrapper(self):
seq_len, bs, obs_shape = 12, 8, 32
layer_num, memory_len, emb_dim = 3, 4, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
model1.show_memory_occupancy()
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model1.forward(inputs1)
new_memory1 = model1.memory
inputs2 = torch.randn((seq_len, bs, obs_shape))
out = model2.forward(inputs2)
new_memory2 = model2.memory
assert not torch.all(torch.eq(new_memory1, new_memory2))
model1.reset(data_id=[0, 5])
assert sum(model1.memory[:, :, 0].flatten()) == 0 and sum(model1.memory[:, :, 5].flatten()) == 0
assert sum(model1.memory[:, :, 1].flatten()) != 0
model1.reset()
assert sum(model1.memory.flatten()) == 0
seq_len, bs, obs_shape = 8, 8, 32
layer_num, memory_len, emb_dim = 3, 20, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs1)
new_memory1 = model.memory
inputs2 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs2)
new_memory2 = model.memory
print(new_memory1.shape, inputs1.shape)
assert sum(new_memory1[:, -8:].flatten()) != 0
assert sum(new_memory1[:, :-8].flatten()) == 0
assert sum(new_memory2[:, -16:].flatten()) != 0
assert sum(new_memory2[:, :-16].flatten()) == 0
assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
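    # 'combination_argmax_sample' wrapper: greedily picks shot_number actions per sample.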
def test_combination_argmax_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
data = {'obs': torch.randn(4, 3)}
shot_number = 2
output = model.forward(shot_number=shot_number, inputs=data)
assert output['action'].shape == (4, shot_number)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()
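    # 'combination_multinomial_sample' wrapper: multinomially samples shot_number actions per sample.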
def test_combination_multinomial_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
data = {'obs': torch.randn(4, 3)}
shot_number = 2
output = model.forward(shot_number=shot_number, inputs=data)
assert output['action'].shape == (4, shot_number)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()