File size: 12,607 Bytes
b20af9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from itertools import product, permutations, combinations_with_replacement, chain


class Unary(nn.Module):
    """Local (per-entity) potential of a single utility.

    Each entity embedding is projected with a 1x1 convolution, passed through
    ReLU and dropout, and reduced to one scalar score.
    """

    def __init__(self, embed_size):
        """
        :param embed_size: the embedding dimension of the utility's entities
        """
        super().__init__()
        self.embed = nn.Conv1d(embed_size, embed_size, 1)
        self.feature_reduce = nn.Conv1d(embed_size, 1, 1)

    def forward(self, X):
        """
        :param X: tensor indexed as (batch, entities, embed_size); the transpose
         below moves the embedding axis into Conv1d's channel position
        :return: tensor of shape (batch, entities), one score per entity
        """
        channels_first = X.transpose(1, 2)
        hidden = F.relu(self.embed(channels_first))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        scores = self.feature_reduce(hidden)
        return scores.squeeze(1)


class Pairwise(nn.Module):
    def __init__(self, embed_x_size, x_spatial_dim=None, embed_y_size=None, y_spatial_dim=None):
        """
            Captures interaction between utilities or entities of the same utility
        :param embed_x_size: the embedding dimension of the first utility
        :param x_spatial_dim: the spatial dimension of the first utility for batch norm and weighted marginalization
        :param embed_y_size: the embedding dimension of the second utility (None: same as the first)
        :param y_spatial_dim: the spatial dimension of the second utility (None: same as the first)
        """

        super(Pairwise, self).__init__()
        # Each fallback is keyed on its own argument: the second utility's embedding
        # size and spatial size are independent of each other.
        embed_y_size = embed_y_size if embed_y_size is not None else embed_x_size
        self.y_spatial_dim = y_spatial_dim if y_spatial_dim is not None else x_spatial_dim

        # Both utilities are projected into a common embedding space.
        self.embed_size = max(embed_x_size, embed_y_size)
        self.x_spatial_dim = x_spatial_dim

        self.embed_X = nn.Conv1d(embed_x_size, self.embed_size, 1)
        self.embed_Y = nn.Conv1d(embed_y_size, self.embed_size, 1)
        if x_spatial_dim is not None:
            # Known spatial sizes allow batch-normalising the flattened similarity
            # matrix and learning weighted marginals instead of plain means.
            self.normalize_S = nn.BatchNorm1d(self.x_spatial_dim * self.y_spatial_dim)

            self.margin_X = nn.Conv1d(self.y_spatial_dim, 1, 1)
            self.margin_Y = nn.Conv1d(self.x_spatial_dim, 1, 1)

    def forward(self, X, Y=None):
        """
        :param X: tensor indexed as (batch, x_entities, embed_x_size)
        :param Y: tensor indexed as (batch, y_entities, embed_y_size), or None for
         self interaction (Y = X)
        :return: X potentials of shape (batch, x_entities) when Y is None, otherwise
         the tuple (X potentials, Y potentials of shape (batch, y_entities))
        """
        X_t = X.transpose(1, 2)
        Y_t = Y.transpose(1, 2) if Y is not None else X_t

        X_embed = self.embed_X(X_t)
        Y_embed = self.embed_Y(Y_t)

        # L2-normalise along the embedding (channel) axis so S holds cosine similarities.
        X_norm = F.normalize(X_embed)
        Y_norm = F.normalize(Y_embed)

        # S[b, x, y]: similarity between entity x of X and entity y of Y.
        S = X_norm.transpose(1, 2).bmm(Y_norm)
        if self.x_spatial_dim is not None:
            S = self.normalize_S(S.view(-1, self.x_spatial_dim * self.y_spatial_dim)) \
                .view(-1, self.x_spatial_dim, self.y_spatial_dim)

            # Learned weighted marginalisation over the other utility's entities.
            X_poten = self.margin_X(S.transpose(1, 2)).transpose(1, 2).squeeze(2)
            Y_poten = self.margin_Y(S).transpose(1, 2).squeeze(2)
        else:
            X_poten = S.mean(dim=2, keepdim=False)
            Y_poten = S.mean(dim=1, keepdim=False)

        if Y is None:
            return X_poten
        return X_poten, Y_poten


class Atten(nn.Module):
    def __init__(self, util_e, sharing_factor_weights=None, prior_flag=False,
                 sizes=None, size_force=False, pairwise_flag=True,
                 unary_flag=True, self_flag=True):
        """
            Performs factor-graph attention over a list of utility representations.
        :param util_e: the embedding dimension of each utility (one entry per utility)
        :param sharing_factor_weights: utilities whose factor weights are shared, as a dict
         {idx: (num_utils, connected_utils)}. The num_utils instances of utility idx are
         folded into the batch axis (b_size * num_utils), and idx is connected only to the
         utilities listed in connected_utils. Note, for efficiency, shared utils (e.g.,
         history) are connected to the answer and question only. None or an empty
         container means no sharing.
         TODO: connections between shared utils
        :param prior_flag: whether a prior factor is provided for every utility in forward()
        :param sizes: the spatial dimension of each utility (used for batch-norm and weighted
         marginalization); None or empty means unknown for every utility
        :param size_force: force the provided spatial sizes with adaptive avg pooling
        :param pairwise_flag: use pairwise interactions between utilities
        :param unary_flag: use local information
        :param self_flag: use self interactions between a utility's entities
        """
        super(Atten, self).__init__()
        self.util_e = util_e

        self.prior_flag = prior_flag

        self.n_utils = len(util_e)

        self.spatial_pool = nn.ModuleDict()

        self.un_models = nn.ModuleList()

        self.self_flag = self_flag
        self.pairwise_flag = pairwise_flag
        self.unary_flag = unary_flag
        self.size_force = size_force

        # No sizes given: every utility has an unknown spatial dimension.
        if not sizes:
            sizes = [None for _ in util_e]

        # Normalise to a dict: the code below relies on index membership (`idx in ...`)
        # and `.items()`. None / [] (the historical default) mean "no sharing".
        self.sharing_factor_weights = dict(sharing_factor_weights) if sharing_factor_weights else {}

        # Unary model per utility; optionally force the provided spatial size.
        for idx, e_dim in enumerate(util_e):
            self.un_models.append(Unary(e_dim))
            if self.size_force:
                self.spatial_pool[str(idx)] = nn.AdaptiveAvgPool1d(sizes[idx])

        # Pairwise models: key "i" for self interactions, "(i, j)" with i <= j for pairs.
        self.pp_models = nn.ModuleDict()
        for ((idx1, e_dim_1), (idx2, e_dim_2)) \
                in combinations_with_replacement(enumerate(util_e), 2):
            # self
            if self.self_flag and idx1 == idx2:
                self.pp_models[str(idx1)] = Pairwise(e_dim_1, sizes[idx1])
            elif pairwise_flag:
                # A shared utility is only connected to the utilities it lists.
                if idx1 in self.sharing_factor_weights \
                        and idx2 not in self.sharing_factor_weights[idx1][1]:
                    continue
                if idx2 in self.sharing_factor_weights \
                        and idx1 not in self.sharing_factor_weights[idx2][1]:
                    continue
                # NOTE(review): with self_flag=False this also builds an (i, i) model that
                # forward()'s joint loop skips; kept so existing state_dict keys are unchanged.
                self.pp_models[str((idx1, idx2))] = Pairwise(e_dim_1, sizes[idx1], e_dim_2, sizes[idx2])

        # Handle reduce potentials (with scalars)
        self.reduce_potentials = nn.ModuleList()

        self.num_of_potentials = dict()

        self.default_num_of_potentials = 0

        if self.self_flag:
            self.default_num_of_potentials += 1
        if self.unary_flag:
            self.default_num_of_potentials += 1
        if self.prior_flag:
            self.default_num_of_potentials += 1
        for idx in range(self.n_utils):
            self.num_of_potentials[idx] = self.default_num_of_potentials

        # Pairwise potentials contributed by all other utilities.
        if pairwise_flag:
            for idx, (num_utils, connected_utils) in self.sharing_factor_weights.items():
                for c_u in connected_utils:
                    # forward() reshapes the factor for c_u to (b_size, num_utils, ...):
                    # one channel per shared instance ...
                    self.num_of_potentials[c_u] += num_utils
                    # ... while the shared utility gets one factor per connection.
                    self.num_of_potentials[idx] += 1
            for k in self.num_of_potentials:
                if k not in self.sharing_factor_weights:
                    # one factor per other non-shared utility
                    self.num_of_potentials[k] += (self.n_utils - 1) \
                                                 - len(self.sharing_factor_weights)

        # A bias-free 1x1 conv mixes a utility's stacked potentials into one score per entity.
        for idx in range(self.n_utils):
            self.reduce_potentials.append(nn.Conv1d(self.num_of_potentials[idx],
                                                    1, 1, bias=False))

    def forward(self, utils, priors=None):
        """
            Computes one attended representation per utility.
        :param utils: sequence with one tensor per utility; non-shared utilities are indexed
         as (batch, entities, embed) (TODO confirm for shared utilities, which are reshaped
         with num_utils folded into the batch axis)
        :param priors: sequence with one prior per utility (entries may be None) when
         prior_flag is set, otherwise None
        :return: list with one tensor per utility: the softmax-weighted sum of its entities
        """
        assert self.n_utils == len(utils)
        assert (priors is None and not self.prior_flag) \
               or (priors is not None
                   and self.prior_flag
                   and len(priors) == self.n_utils)
        # Shallow copies: the resizing below must not overwrite the caller's containers.
        utils = list(utils)
        if priors is not None:
            priors = list(priors)
        b_size = utils[0].size(0)
        util_factors = dict()
        attention = list()

        # Force size, constant size is used for pairwise batch normalization
        if self.size_force:
            for i, (num_utils, _) in self.sharing_factor_weights.items():
                if str(i) not in self.spatial_pool.keys():
                    continue
                high_util = utils[i]
                # Fold the shared instances into the batch axis before pooling.
                high_util = high_util.view(num_utils * b_size, high_util.size(2), high_util.size(3))
                high_util = high_util.transpose(1, 2)
                utils[i] = self.spatial_pool[str(i)](high_util).transpose(1, 2)

            for i in range(self.n_utils):
                if i in self.sharing_factor_weights \
                        or str(i) not in self.spatial_pool.keys():
                    continue
                # AdaptiveAvgPool1d pools the last axis, so move the entities there and back.
                utils[i] = utils[i].transpose(1, 2)
                utils[i] = self.spatial_pool[str(i)](utils[i]).transpose(1, 2)
                if self.prior_flag and priors[i] is not None:
                    priors[i] = self.spatial_pool[str(i)](priors[i].unsqueeze(1)).squeeze(1)

        # handle Shared weights
        for i, (num_utils, connected_list) in self.sharing_factor_weights.items():
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))

            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))

            if self.pairwise_flag:
                for j in connected_list:
                    other_util = utils[j]
                    # Repeat utility j once per shared instance so it lines up with utils[i].
                    expanded_util = other_util.unsqueeze(1).expand(b_size,
                                                                   num_utils,
                                                                   other_util.size(1),
                                                                   other_util.size(2)).contiguous().view(
                        b_size * num_utils,
                        other_util.size(1),
                        other_util.size(2))

                    # Pairwise models are keyed with the smaller index first.
                    if i < j:
                        factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], expanded_util)
                    else:
                        factor_ji, factor_ij = self.pp_models[str((j, i))](expanded_util, utils[i])
                    util_factors.setdefault(i, []).append(factor_ij)
                    util_factors.setdefault(j, []).append(factor_ji.view(b_size, num_utils, factor_ji.size(1)))

        # handle local factors
        for i in range(self.n_utils):
            if i in self.sharing_factor_weights:
                continue
            if self.unary_flag:
                util_factors.setdefault(i, []).append(self.un_models[i](utils[i]))
            if self.self_flag:
                util_factors.setdefault(i, []).append(self.pp_models[str(i)](utils[i]))

        # joint
        if self.pairwise_flag:
            for (i, j) in combinations_with_replacement(range(self.n_utils), 2):
                if i in self.sharing_factor_weights \
                        or j in self.sharing_factor_weights:
                    continue
                if i == j:
                    continue
                factor_ij, factor_ji = self.pp_models[str((i, j))](utils[i], utils[j])
                util_factors.setdefault(i, []).append(factor_ij)
                util_factors.setdefault(j, []).append(factor_ji)

        # perform attention
        for i in range(self.n_utils):
            if self.prior_flag:
                # A missing prior becomes an all-zero factor; zeros_like keeps the
                # device/dtype of the other factors (works on CPU and GPU alike).
                prior = priors[i] \
                    if priors[i] is not None \
                    else torch.zeros_like(util_factors[i][0], requires_grad=False)

                util_factors[i].append(prior)

            # Stack every potential as a channel: 2-D factors get a channel axis, the
            # 3-D factors from shared utilities already carry num_utils channels.
            util_factors[i] = torch.cat([p if p.dim() == 3 else p.unsqueeze(1)
                                         for p in util_factors[i]], dim=1)
            util_factors[i] = self.reduce_potentials[i](util_factors[i]).squeeze(1)
            util_factors[i] = F.softmax(util_factors[i], dim=1).unsqueeze(2)
            attention.append(torch.bmm(utils[i].transpose(1, 2), util_factors[i]).squeeze(2))

        return attention


class NaiveAttention(nn.Module):
    """Attention-free baseline, used for ablation analysis (removing attention).

    Each utility is collapsed with its prior when one is given, otherwise with a
    plain mean; the module has no learned parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, utils, priors):
        """
        :param utils: list of utilities; a tuple entry is handled separately and
         only its second element is used
        :param priors: list of priors aligned with utils (entries may be None)
        :return: (atten, spatial_atten) - results for non-tuple and tuple
         utilities respectively, each in input order
        """
        atten, spatial_atten = [], []
        for util, prior in zip(utils, priors):
            if type(util) is not tuple:
                if prior is None:
                    atten.append(util.mean(1))
                else:
                    weights = prior.unsqueeze(2)
                    atten.append(torch.bmm(util.transpose(1, 2), weights).squeeze(2))
                continue

            feats = util[1]
            if prior is None:
                spatial_atten.append(feats.mean(2))
                continue

            num_elements = feats.shape[0]
            rows, cols = feats.shape[-2], feats.shape[-1]
            # Flatten all leading axes into one batch axis for bmm.
            flat_feats = feats.view(-1, rows, cols)
            flat_prior = prior.view(-1, prior.shape[-2], prior.shape[-1])
            pooled = torch.bmm(flat_prior.transpose(1, 2), flat_feats).squeeze(2)
            spatial_atten.append(pooled.view(num_elements, -1, rows, cols))
        return atten, spatial_atten