Spaces:
Running
Running
#!/usr/bin/env python | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from itertools import product, permutations, combinations_with_replacement, chain | |
class Unary(nn.Module):
    """
    Local (per-entity) potential of a single utility.

    Each entity's embedding goes through a 1x1 convolution, a ReLU and
    functional dropout (active only while ``self.training`` is set). A second
    1x1 convolution then reduces it to one scalar score per entity.
    """

    def __init__(self, embed_size):
        """
        :param embed_size: the embedding dimension of the utility
        """
        super(Unary, self).__init__()
        self.embed = nn.Conv1d(embed_size, embed_size, 1)
        self.feature_reduce = nn.Conv1d(embed_size, 1, 1)

    def forward(self, X):
        """
        :param X: utility tensor laid out as (batch, entities, embed_size);
                  it is transposed to channels-first for the 1x1 convolutions
        :return: one score per entity, shape (batch, entities)
        """
        channels_first = X.transpose(1, 2)
        hidden = F.dropout(F.relu(self.embed(channels_first)), training=self.training)
        scores = self.feature_reduce(hidden)  # (batch, 1, entities)
        return scores.squeeze(1)
class Pairwise(nn.Module):
    def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
        """
        Captures interaction between two utilities, or between the entities of
        a single utility (self interaction).

        :param embed_x_size: the embedding dimension of the first utility
        :param x_spatial_dim: the spatial dimension (number of entities) of the
               first utility, used for batch norm and weighted marginalization;
               when None, the similarity matrix is marginalized by a plain mean
        :param embed_y_size: the embedding dimension of the second utility
               (None for self interactions, in which case embed_x_size is used)
        :param y_spatial_dim: the spatial dimension of the second utility for batch
               norm and weighted marginalization (defaults to x_spatial_dim)
        """
        super(Pairwise, self).__init__()
        # Fall back to the first utility's embedding size only when the second
        # one is missing. The previous check tested y_spatial_dim, so a provided
        # embed_y_size was discarded whenever no spatial size was given. That
        # left embed_Y with the wrong number of input channels.
        if embed_y_size is None:
            embed_y_size = embed_x_size
        self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim
        # Both utilities are projected into a common space of the larger size.
        self.embed_size = max(embed_x_size, embed_y_size)
        self.x_spatial_dim = x_spatial_dim
        self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
        self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
        if x_spatial_dim is not None:
            # Batch norm over the flattened (x_spatial_dim * y_spatial_dim) similarity matrix.
            self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)
            # Learned weighted marginalization over the other utility's entities.
            self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
            self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)

    def forward(self, X, Y=None):
        """
        :param X: first utility, laid out as (batch, n_x, embed_x_size)
        :param Y: second utility, (batch, n_y, embed_y_size), or None for a self
                  interaction (X is then also fed through embed_Y)
        :return: X_poten of shape (batch, n_x) when Y is None; otherwise the
                 tuple (X_poten, Y_poten) with Y_poten of shape (batch, n_y).
                 When spatial dims were configured, n_x / n_y must equal them.
        """
        X_t = X.transpose(1, 2)
        Y_t = Y.transpose(1, 2) if Y is not None else X_t
        # F.normalize defaults to dim=1: each entity's embedding is L2-normalized,
        # so S below holds cosine similarities between entities.
        X_norm = F.normalize(self.embed_X(X_t))
        Y_norm = F.normalize(self.embed_Y(Y_t))
        S = X_norm.transpose(1, 2).bmm(Y_norm)  # (batch, n_x, n_y)
        if self.x_spatial_dim is not None:
            S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
                .view(-1, self.x_spatial_dim, self.y_spatial_dim)
            X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
            Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
        else:
            X_poten = S.mean(dim=2, keepdim=False)
            Y_poten = S.mean(dim=1, keepdim=False)
        if Y is None:
            return X_poten
        return X_poten, Y_poten
class Atten(nn.Module):
    def __init__(self, util_e, sharing_factor_weights=None, prior_flag=False,
                 sizes=None, size_force=False, pairwise_flag=True,
                 unary_flag=True, self_flag=True):
        """
        Performs attention over a given list of utility representations.

        :param util_e: the embedding dimension of each utility
        :param sharing_factor_weights: utilities whose weights are shared, as a dict
               {idx: (num_utils, connected_utils)}. For efficiency, a shared utility
               (e.g. history) is connected only to the utilities listed in
               connected_utils (e.g. answer and question).
               TODO: connections between shared utils.
               None (or an empty dict/list) means no sharing.
        :param prior_flag: whether a prior factor is provided to forward()
        :param sizes: the spatial dimension of each utility. It is used for batch norm
               and weighted marginalization in the pairwise factors, and as the
               pooling target when size_force is set. None or empty means unknown.
        :param size_force: force the spatial sizes with adaptive average pooling
        :param pairwise_flag: use pairwise interactions between utilities
        :param unary_flag: use local (per-entity) information
        :param self_flag: use self interactions between a utility's entities
        """
        super(Atten, self).__init__()
        # None replaces the former mutable [] defaults. The sharing spec is always
        # stored as a dict, because it is used both for membership tests and for
        # .items() iteration.
        if sharing_factor_weights is None:
            sharing_factor_weights = {}
        else:
            sharing_factor_weights = dict(sharing_factor_weights)
        if sizes is None or len(sizes) == 0:
            sizes = [None for _ in util_e]

        self.util_e = util_e
        self.prior_flag = prior_flag
        self.n_utils = len(util_e)
        self.spatial_pool = nn.ModuleDict()
        self.un_models = nn.ModuleList()
        self.self_flag = self_flag
        self.pairwise_flag = pairwise_flag
        self.unary_flag = unary_flag
        self.size_force = size_force
        self.sharing_factor_weights = sharing_factor_weights

        # One unary model per utility, plus an optional pooling layer that forces
        # the provided size.
        for idx, e_dim in enumerate(util_e):
            self.un_models.append(Unary(e_dim))
            if self.size_force:
                self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])

        # Pairwise models. Self models are keyed "i"; cross models are keyed
        # "(i, j)" with i < j.
        self.pp_models = nn.ModuleDict()
        for (idx1, e_dim_1), (idx2, e_dim_2) in \
                combinations_with_replacement(enumerate(util_e), 2):
            if self.self_flag and idx1 == idx2:
                self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
                continue
            if not pairwise_flag:
                continue
            # NOTE(review): with self_flag False, idx1 == idx2 reaches this point and
            # registers an "(i, i)" model that forward() never uses. It is kept so
            # that existing checkpoints keep loading with strict=True.
            if idx1 in sharing_factor_weights \
                    and idx2 not in sharing_factor_weights[idx1][1]:
                continue  # shared utility idx1 is not connected to idx2
            if idx2 in sharing_factor_weights \
                    and idx1 not in sharing_factor_weights[idx2][1]:
                continue  # shared utility idx2 is not connected to idx1
            self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1],
                                                         e_dim_2, sizes[idx2])

        # Count the potentials stacked per utility. Each count becomes the number
        # of input channels of that utility's reduce_potentials 1x1 convolution.
        self.reduce_potentials = nn.ModuleList()
        self.num_of_potentials = dict()
        self.default_num_of_potentials = 0
        if self.self_flag:
            self.default_num_of_potentials += 1
        if self.unary_flag:
            self.default_num_of_potentials += 1
        if self.prior_flag:
            self.default_num_of_potentials += 1
        for idx in range(self.n_utils):
            self.num_of_potentials[idx] = self.default_num_of_potentials
        if pairwise_flag:
            # A shared utility adds num_utils channels to each connected utility
            # and receives one channel per connection.
            for idx, (num_utils, connected_utils) in sharing_factor_weights.items():
                for c_u in connected_utils:
                    self.num_of_potentials[c_u] += num_utils
                    self.num_of_potentials[idx] += 1
            # Each non-shared utility receives one channel from every other
            # non-shared utility.
            for k in self.num_of_potentials:
                if k not in sharing_factor_weights:
                    self.num_of_potentials[k] += (self.n_utils - 1) \
                                                 - len(sharing_factor_weights)
        for idx in range(self.n_utils):
            self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
                                                    1, 1, bias=False))

    def forward(self, utils, priors=None):
        """
        Attend over every utility.

        :param utils: sequence of utility tensors. A non-shared utility i is laid out
               as (batch, n_i, util_e[i]). With size_force, a shared utility is
               reshaped via view(num_utils * batch, size(2), size(3)), presumably
               from (batch, num_utils, n_i, util_e[i]) - TODO confirm with callers.
        :param priors: sequence of per-utility priors (entries may be None). Required
               if and only if prior_flag is set. A prior is presumably
               (batch, n_i), so that it adds a single potential channel.
        :return: list with one attended vector per utility, of shape
                 (rows, util_e[i]), where rows is the leading size of utils[i]
                 after optional pooling/flattening.
        """
        assert self.n_utils == len(utils)
        assert (priors is None and not self.prior_flag) \
            or (priors is not None
                and self.prior_flag
                and len(priors) == self.n_utils)
        # Work on copies so that the pooling/reshaping below does not overwrite
        # entries of the caller's lists.
        utils = list(utils)
        if priors is not None:
            priors = list(priors)
        b_size = utils[0].size(0)
        util_factors = dict()

        # Force size; a constant size is needed for the pairwise batch normalization.
        if self.size_force:
            self._force_sizes(utils, priors, b_size)
        self._add_shared_factors(utils, util_factors, b_size)
        self._add_local_factors(utils, util_factors)

        attention = list()
        for i in range(self.n_utils):
            factors = util_factors.setdefault(i, [])
            if self.prior_flag:
                prior = priors[i]
                if prior is None:
                    # Neutral prior with the shape of a unary potential, created on
                    # the utility's own device and dtype. The previous .cuda() call
                    # failed on CPU and could select the wrong GPU.
                    prior = utils[i].new_zeros(utils[i].shape[:2])
                factors.append(prior)
            # 2-D potentials contribute one channel; 3-D ones (from shared
            # utilities) contribute num_utils channels.
            stacked = torch.cat([p if p.dim() == 3 else p.unsqueeze(1) for p in factors],
                                dim=1)
            logits = self.reduce_potentials[i](stacked).squeeze(1)
            weights = F.softmax(logits, dim=1).unsqueeze(2)
            attention.append(torch.bmm(utils[i].transpose(1, 2), weights).squeeze(2))
        return attention

    def _force_sizes(self, utils, priors, b_size):
        """Adaptive-pool utilities (and non-shared priors) to the configured sizes, in place on the given lists."""
        for i, (num_utils, _) in self.sharing_factor_weights.items():
            if str(i) not in self.spatial_pool:
                continue
            high_util = utils[i]
            # Flatten the shared entries into num_utils * b_size rows before pooling.
            high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
            utils[i] = self.spatial_pool[str(i)](high_util.transpose(1, 2)).transpose(1, 2)
        for i in range(self.n_utils):
            if i in self.sharing_factor_weights or str(i) not in self.spatial_pool:
                continue
            utils[i] = self.spatial_pool[str(i)](utils[i].transpose(1, 2)).transpose(1, 2)
            if self.prior_flag and priors[i] is not None:
                priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)

    def _add_shared_factors(self, utils, util_factors, b_size):
        """Append the potentials of shared utilities and of the utilities connected to them."""
        for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
            if not self.pairwise_flag:
                continue
            for j in connected_list:
                other_util = utils[j]
                # Repeat utility j once per shared entry so that it lines up with
                # the b_size * num_utils rows of utility i.
                expanded_util = other_util.unsqueeze(1).expand(
                    b_size, num_utils, other_util.size(1), other_util.size(2)
                ).contiguous().view(b_size * num_utils, other_util.size(1), other_util.size(2))
                # Cross models are registered under (smaller index, larger index).
                if i < j:
                    factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
                else:
                    factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
                util_factors.setdefault(i, []).append(factor_ij)
                # Regroup per batch item: utility j gets num_utils potential channels.
                util_factors.setdefault(j, []).append(
                    factor_ji.view(b_size, num_utils, factor_ji.size(1)))

    def _add_local_factors(self, utils, util_factors):
        """Append the unary, self and pairwise potentials of the non-shared utilities."""
        shared = self.sharing_factor_weights
        for i in range(self.n_utils):
            if i in shared:
                continue
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))
        if self.pairwise_flag:
            for i, j in combinations_with_replacement(range(self.n_utils), 2):
                if i == j or i in shared or j in shared:
                    continue
                factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
                util_factors.setdefault(i, []).append(factor_ij)
                util_factors.setdefault(j, []).append(factor_ji)
class NaiveAttention(nn.Module):
    """
    Attention-free baseline, used for ablation analysis - removing attention.

    A utility is reduced either by weighting its entities with the given prior,
    or, when no prior is given, by a plain mean.
    """

    def __init__(self):
        super(NaiveAttention, self).__init__()

    def forward(self, utils, priors):
        """
        :param utils: sequence of utilities. A plain tensor is reduced over dim 1.
               For a tuple (exact type ``tuple``), only its second element is used,
               and the result goes to the spatial list.
        :param priors: sequence of priors aligned with utils (entries may be None)
        :return: (atten, spatial_atten) - reductions of the plain-tensor utilities
                 and of the tuple utilities, each in input order
        """
        atten, spatial_atten = [], []
        for util, prior in zip(utils, priors):
            if type(util) is not tuple:
                if prior is None:
                    atten.append(util.mean(1))
                else:
                    weighted = torch.bmm(util.transpose(1, 2), prior.unsqueeze(2))
                    atten.append(weighted.squeeze(2))
                continue

            grid = util[1]
            num_elements = grid.shape[0]
            if prior is None:
                spatial_atten.append(grid.mean(2))
                continue
            # Collapse leading dims so that bmm sees 3-D tensors, then restore
            # num_elements as the leading dim.
            flat = grid.view(-1, grid.shape[-2], grid.shape[-1])
            flat_prior = prior.view(-1, prior.shape[-2], prior.shape[-1])
            mixed = torch.bmm(flat_prior.transpose(1, 2), flat).squeeze(2)
            spatial_atten.append(mixed.view(num_elements, -1, flat.shape[-2], flat.shape[-1]))
        return atten, spatial_atten