GreedRL / greedrl /decode.py
先坤
add greedrl
db26c81
raw
history blame
6.49 kB
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
class MultiHeadAttention(nn.Module):
def __init__(self, heads, hidden_dim):
super(MultiHeadAttention, self).__init__()
assert hidden_dim % heads == 0
self.heads = heads
head_dim = hidden_dim // heads
self.alpha = 1 / math.sqrt(head_dim)
self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
for param in self.parameters():
stdv = 1. / math.sqrt(param.size(-1))
param.data.uniform_(-stdv, stdv)
def forward(self, q, K, V, mask):
batch_size, query_num, hidden_dim = q.size()
size = (self.heads, batch_size, query_num, -1)
q = q.reshape(-1, hidden_dim)
Q = torch.matmul(q, self.nn_Q).view(size)
value_num = V.size(2)
heads_batch = self.heads * batch_size
Q = Q.view(heads_batch, query_num, -1)
K = K.view(heads_batch, value_num, -1).transpose(1, 2)
S = masked_tensor(mask, self.heads)
S = S.view(heads_batch, query_num, value_num)
S.baddbmm_(Q, K, alpha=self.alpha)
S = S.view(self.heads, batch_size, query_num, value_num)
S = F.softmax(S, dim=-1)
x = torch.matmul(S, V).permute(1, 2, 0, 3)
x = x.reshape(batch_size, query_num, -1)
x = torch.matmul(x, self.nn_O)
return x
class Decode(nn.Module):
def __init__(self, nn_args):
super(Decode, self).__init__()
self.nn_args = nn_args
heads = nn_args['decode_atten_heads']
hidden_dim = nn_args['decode_hidden_dim']
self.heads = heads
self.alpha = 1 / math.sqrt(hidden_dim)
if heads > 0:
assert hidden_dim % heads == 0
head_dim = hidden_dim // heads
self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
self.nn_mha = MultiHeadAttention(heads, hidden_dim)
decode_rnn = nn_args.setdefault('decode_rnn', 'LSTM')
assert decode_rnn in ('GRU', 'LSTM', 'NONE')
if decode_rnn == 'GRU':
self.nn_rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)
elif decode_rnn == 'LSTM':
self.nn_rnn_cell = nn.LSTMCell(hidden_dim, hidden_dim)
else:
self.nn_rnn_cell = None
self.vars_dim = sum(nn_args['variable_dim'].values())
if self.vars_dim > 0:
atten_type = nn_args.setdefault('decode_atten_type', 'add')
assert atten_type == 'add', "must be addition attention when vars_dim > 0, {}".format(atten_type)
self.nn_A = nn.Parameter(torch.Tensor(self.vars_dim, hidden_dim))
self.nn_B = nn.Parameter(torch.Tensor(hidden_dim))
else:
atten_type = nn_args.setdefault('decode_atten_type', 'prod')
if atten_type == 'add':
self.nn_W = nn.Parameter(torch.Tensor(hidden_dim))
else:
self.nn_W = None
for param in self.parameters():
stdv = 1 / math.sqrt(param.size(-1))
param.data.uniform_(-stdv, stdv)
def forward(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt=0):
if self.training and memopt > 2:
state1, state2 = checkpoint(self.rnn_step, query, state1, state2)
else:
state1, state2 = self.rnn_step(query, state1, state2)
query = state1
NP = X.size(0)
NR = query.size(0) // NP
batch_size = query.size(0)
if self.heads > 0:
query = query.view(NP, NR, -1)
if self.training and memopt > 1:
query = checkpoint(self.nn_mha, query, K, V, mask)
else:
query = self.nn_mha(query, K, V, mask)
query = query.view(batch_size, -1)
if self.nn_W is None:
query = query.view(NP, NR, -1)
logit = masked_tensor(mask, 1)
logit = logit.view(NP, NR, -1)
X = X.permute(0, 2, 1)
logit.baddbmm_(query, X, alpha=self.alpha)
logit = logit.view(batch_size, -1)
else:
if self.training and self.vars_dim > 0 and memopt > 0:
logit = checkpoint(self.atten, query, X, varfeat, mask)
else:
logit = self.atten(query, X, varfeat, mask)
chosen_p = choose(logit, chosen, sample_p, clip, mode)
return state1, state2, chosen_p
def rnn_step(self, query, state1, state2):
if isinstance(self.nn_rnn_cell, nn.GRUCell):
state1 = self.nn_rnn_cell(query, state1)
elif isinstance(self.nn_rnn_cell, nn.LSTMCell):
state1, state2 = self.nn_rnn_cell(query, (state1, state2))
return state1, state2
def atten(self, query, keyvalue, varfeat, mask):
if self.vars_dim > 0:
varfeat = vfaddmm(varfeat, mask, self.nn_A, self.nn_B)
return atten(query, keyvalue, varfeat, mask, self.nn_W)
def choose(logit, chosen, sample_p, clip, mode):
mask = logit == -math.inf
logit = torch.tanh(logit) * clip
logit[mask] = -math.inf
if mode == 0:
pass
elif mode == 1:
chosen[:] = logit.argmax(1)
elif mode == 2:
p = logit.exp()
chosen[:] = torch.multinomial(p, 1).squeeze(1)
else:
raise Exception()
logp = logit.log_softmax(1)
logp = logp.gather(1, chosen[:, None])
logp = logp.squeeze(1)
return logp
def atten(query, keyvalue, varfeat, mask, weight):
batch_size = query.size(0)
NP, NK, ND = keyvalue.size()
query = query.view(NP, -1, 1, ND)
varfeat = varfeat.view(NP, -1, NK, ND)
keyvalue = keyvalue[:, None, :, :]
keyvalue = keyvalue + varfeat + query
keyvalue = torch.tanh(keyvalue)
keyvalue = keyvalue.view(-1, ND)
logit = masked_tensor(mask, 1).view(-1)
logit.addmv_(keyvalue, weight)
return logit.view(batch_size, -1)
def masked_tensor(mask, heads):
size = list(mask.size())
size.insert(0, heads)
mask = mask[None].expand(size)
result = mask.new_zeros(size, dtype=torch.float32)
result[mask] = -math.inf
return result
def vfaddmm(varfeat, mask, A, B):
varfeat = varfeat.permute(0, 2, 1)
return F.linear(varfeat, A.permute(1, 0), B)