# Implementation for paper 'Attention on Attention for Image Captioning'
# https://arxiv.org/abs/1908.06954
# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from .AttModel import pack_wrapper, AttModel, Attention
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
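
# Note (paraphrasing the paper linked above): "Attention on Attention" (AoA) augments the
# result v_hat of an attention module, queried by q, with a gated update
#     AoA(v_hat, q) = sigmoid(W_g^q q + W_g^v v_hat + b_g) * (W_i^q q + W_i^v v_hat + b_i),
# i.e. an "information vector" modulated elementwise by an "attention gate". In this file
# the two affine maps plus the sigmoid gating are realized jointly by an nn.Linear with
# 2 * d_model outputs followed by nn.GLU(), which splits its input into two halves and
# multiplies the first half by the sigmoid of the second.
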
class MultiHeadedDotAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
        super(MultiHeadedDotAttention, self).__init__()
        assert d_model * scale % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model * scale // h
        self.h = h

        # Do we need to do linear projections on K and V?
        self.project_k_v = project_k_v

        # normalize the query?
        if norm_q:
            self.norm = LayerNorm(d_model)
        else:
            self.norm = lambda x: x
        self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)

        # output linear layer after the multi-head attention?
        self.output_layer = nn.Linear(d_model * scale, d_model)

        # apply AoA after attention?
        self.use_aoa = do_aoa
        if self.use_aoa:
            self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
            # dropout to the input of the AoA layer
            if dropout_aoa > 0:
                self.dropout_aoa = nn.Dropout(p=dropout_aoa)
            else:
                self.dropout_aoa = lambda x: x

        if self.use_aoa or not use_output_layer:
            # AoA doesn't need the output linear layer
            del self.output_layer
            self.output_layer = lambda x: x

        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
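
    # forward() computes standard scaled dot-product multi-head attention (via the shared
    # attention() helper) and, when do_aoa is set, applies the AoA gate: the attended
    # result is concatenated with the (optionally normalized) query and fed through
    # aoa_layer (Linear + GLU). Note the argument order is (query, value, key).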
    def forward(self, query, value, key, mask=None):
        if mask is not None:
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(-2)
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)

        single_query = 0
        if len(query.size()) == 2:
            single_query = 1
            query = query.unsqueeze(1)

        nbatches = query.size(0)

        query = self.norm(query)

        # Do all the linear projections in batch from d_model => h x d_k
        if self.project_k_v == 0:
            query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        else:
            query_, key_, value_ = \
                [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                 for l, x in zip(self.linears, (query, key, value))]

        # Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query_, key_, value_, mask=mask,
                                 dropout=self.dropout)

        # "Concat" using a view
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        if self.use_aoa:
            # Apply AoA
            x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
        x = self.output_layer(x)

        if single_query:
            query = query.squeeze(1)
            x = x.squeeze(1)
        return x
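

# A single refining block: a Transformer-style encoder layer (residual connections wrapped
# by SublayerConnection) whose self-attention is the AoA multi-head attention above; the
# position-wise feed-forward sublayer is only used when one is passed in.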
class AoA_Refiner_Layer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(AoA_Refiner_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.use_ff = 0
        if self.feed_forward is not None:
            self.use_ff = 1
        self.sublayer = clones(SublayerConnection(size, dropout), 1 + self.use_ff)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x


class AoA_Refiner_Core(nn.Module):
    def __init__(self, opt):
        super(AoA_Refiner_Core, self).__init__()
        attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
        layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
        self.layers = clones(layer, 6)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
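

# Decoder core for one time step: an attention LSTM consumes the input embedding xt
# together with the mean-pooled image feature plus the previous context vector; its hidden
# state h_att then attends over the refined region features, and att2ctx (AoA, i.e.
# Linear + GLU, by default) maps [attended feature; h_att] to the new context vector,
# which is also the step output.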
class AoA_Decoder_Core(nn.Module):
    def __init__(self, opt):
        super(AoA_Decoder_Core, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm
        self.d_model = opt.rnn_size
        self.use_multi_head = opt.use_multi_head
        self.multi_head_scale = opt.multi_head_scale
        self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
        self.out_res = getattr(opt, 'out_res', 0)
        self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size)  # we, fc, h^2_t-1
        self.out_drop = nn.Dropout(self.drop_prob_lm)

        if self.decoder_type == 'AoA':
            # AoA layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
        elif self.decoder_type == 'LSTM':
            # LSTM layer
            self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
        else:
            # Base linear layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())

        # if opt.use_multi_head == 1: # TODO, not implemented for now
        #     self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
        if opt.use_multi_head == 2:
            self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
        else:
            self.attention = Attention(opt)

        if self.use_ctx_drop:
            self.ctx_drop = nn.Dropout(self.drop_prob_lm)
        else:
            self.ctx_drop = lambda x: x
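
    # State layout: state[0] stacks (h_att, context vector) and state[1] stacks
    # (c_att, LSTM cell state or a carried-over placeholder), so state[0][1] brings the
    # previous step's context vector into the next call (matching num_layers = 2 in
    # AoAModel below).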
    def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
        # state[0][1] is the context vector at the last step
        h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))

        if self.use_multi_head == 2:
            att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
        else:
            att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        ctx_input = torch.cat([att, h_att], 1)
        if self.decoder_type == 'LSTM':
            output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
            state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
        else:
            output = self.att2ctx(ctx_input)
            # save the context vector to state[0][1]
            state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))

        if self.out_res:
            # add residual connection
            output = output + h_att

        output = self.out_drop(output)
        return output, state
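

# Full captioning model. The AttModel base class provides the embedding layers and the
# step-by-step decoding loop; this class plugs in the AoA refiner (when opt.refine is set)
# and the AoA decoder core, and precomputes the attention projections in
# _prepare_feature below.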
class AoAModel(AttModel):
    def __init__(self, opt):
        super(AoAModel, self).__init__(opt)
        self.num_layers = 2
        # mean pooling
        self.use_mean_feats = getattr(opt, 'mean_feats', 1)
        if opt.use_multi_head == 2:
            del self.ctx2att
            self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)

        if self.use_mean_feats:
            del self.fc_embed
        if opt.refine:
            self.refiner = AoA_Refiner_Core(opt)
        else:
            self.refiner = lambda x, y: x
        self.core = AoA_Decoder_Core(opt)
        self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
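
    # Embeds and (optionally) refines the region features, mean-pools them (with masking
    # when att_masks is given), and projects them once with ctx2att so that per-step
    # attention can reuse p_att_feats; with use_multi_head == 2 the projection holds both
    # halves that the decoder later splits with narrow().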
    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed att feats
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
        att_feats = self.refiner(att_feats, att_masks)

        if self.use_mean_feats:
            # mean pooling
            if att_masks is None:
                mean_feats = torch.mean(att_feats, dim=1)
            else:
                mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
        else:
            mean_feats = self.fc_embed(fc_feats)

        # Project the attention feats first to reduce memory and computation.
        p_att_feats = self.ctx2att(att_feats)

        return mean_feats, att_feats, p_att_feats, att_masks