import torch import torch.nn as nn import torch.nn.functional as F from modules.FGA.atten import Atten class FGA(nn.Module): def __init__(self, vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim, hidden_hist_dim, hidden_cap_dim, hidden_img_dim): ''' Factor Graph Attention :param vocab_size: vocabulary size :param word_embed_dim :param hidden_ques_dim: :param hidden_ans_dim: :param hidden_hist_dim: :param img_features_dim: ''' super(FGA, self).__init__() print("Init FGA with vocab size %s, word embed %s, hidden ques %s, hidden ans %s," " hidden hist %s, hidden cap %s, hidden img %s" % (vocab_size, word_embed_dim, hidden_ques_dim, hidden_ans_dim, hidden_hist_dim, hidden_cap_dim, hidden_img_dim)) self.hidden_ques_dim = hidden_ques_dim self.hidden_ans_dim = hidden_ans_dim self.hidden_cap_dim = hidden_cap_dim self.hidden_img_dim = hidden_img_dim self.hidden_hist_dim = hidden_hist_dim # Vocab of History LSTMs is one more as we are keeping a stop id (the last id) self.word_embedddings = nn.Embedding(vocab_size+1+1, word_embed_dim, padding_idx=0) self.lstm_ques = nn.LSTM(word_embed_dim, self.hidden_ques_dim, batch_first=True) self.lstm_ans = nn.LSTM(word_embed_dim, self.hidden_ans_dim, batch_first=True) self.lstm_hist_ques = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True) self.lstm_hist_ans = nn.LSTM(word_embed_dim, self.hidden_hist_dim, batch_first=True) self.lstm_hist_cap = nn.LSTM(word_embed_dim, self.hidden_cap_dim, batch_first=True) self.qahistnet = nn.Sequential( nn.Linear(self.hidden_hist_dim*2, self.hidden_hist_dim), nn.ReLU(inplace=True) ) self.concat_dim = self.hidden_ques_dim + self.hidden_ans_dim + \ self.hidden_ans_dim + self.hidden_img_dim + \ self.hidden_cap_dim + self.hidden_hist_dim*9 self.simnet = nn.Sequential( nn.Linear(self.concat_dim, (self.concat_dim)//2, bias=False), nn.BatchNorm1d((self.concat_dim) // 2), nn.ReLU(inplace=True), nn.Linear((self.concat_dim)//2, (self.concat_dim)//4, bias=False), nn.BatchNorm1d((self.concat_dim) // 4), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Linear((self.concat_dim)//4, 1) ) # To share weights, provide list of tuples: (idx, list of connected utils) # Note, for efficiency, the shared utils (i.e., history, are connected to ans and question only. # connecting shared factors is not supported (!) sharing_factor_weights = {4: (9, [0, 1]), 5: (9, [0, 1])} self.mul_atten = Atten(util_e=[self.hidden_ans_dim, # Answer modal self.hidden_ques_dim, # Question modal self.hidden_cap_dim, # Caption modal self.hidden_img_dim, # Image modal self.hidden_hist_dim, # Question-history modal self.hidden_hist_dim # Answer-history modal ], sharing_factor_weights=sharing_factor_weights, sizes=[100, # 100 Answers 21, # Question length 41, # Caption length 37, # 36 Image regions 21, # History-Question length 21 # History-Answer length ] # The spatial dim used for pairwise normalization (use force for adaptive) , prior_flag=True, pairwise_flag=True) def forward(self, input_ques, input_ans, input_hist_ques, input_hist_ans, input_hist_cap, input_ques_length, input_ans_length, input_cap_length, i_e): """ :param input_ques: :param input_ans: :param input_hist_ques: :param input_hist_ans: :param input_hist_cap: :param input_ques_length: :param input_ans_length: :param input_cap_length: :param i_e: :return: """ n_options = input_ans.size()[1] batch_size = input_ques.size()[0] nqa_per_dial, nwords_per_qa = input_hist_ques.size()[1], input_hist_ques.size()[2] nwords_per_cap = input_hist_cap.size()[1] max_length_input_ans = input_ans.size()[-1] assert batch_size == input_hist_ques.size()[0] == input_hist_ans.size()[0] == input_ques.size()[0] == \ input_ans.size()[0] == input_hist_cap.size()[0] assert nqa_per_dial == input_hist_ques.size()[1] == input_hist_ans.size()[1] assert nwords_per_qa == input_hist_ques.size()[2] == input_hist_ans.size()[2] q_we = self.word_embedddings(input_ques) a_we = self.word_embedddings(input_ans.view(-1, max_length_input_ans)) hq_we = self.word_embedddings(input_hist_ques.view(-1, nwords_per_qa)) ha_we = self.word_embedddings(input_hist_ans.view(-1, nwords_per_qa)) c_we = self.word_embedddings(input_hist_cap.view(-1, nwords_per_cap)) ''' q_we = batch x 20 x embed_ques_dim a_we = 100*batch x 20 x embed_ans_dim hq_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim ha_we = batch*nqa_per_dial, nwords_per_qa, embed_hist_dim c_we = batch*ncap_per_dial, nwords_per_cap, embed_hist_dim ''' self.lstm_ques.flatten_parameters() self.lstm_ans.flatten_parameters() self.lstm_hist_ques.flatten_parameters() self.lstm_hist_ans.flatten_parameters() self.lstm_hist_cap.flatten_parameters() i_feat = i_e q_seq, self.hidden_ques = self.lstm_ques(q_we) a_seq, self.hidden_ans = self.lstm_ans(a_we) hq_seq, self.hidden_hist_ques = self.lstm_hist_ques(hq_we) ha_seq, self.hidden_hist_ans = self.lstm_hist_ans(ha_we) cap_seq, self.hidden_cap = self.lstm_hist_cap(c_we) ''' length is used for attention prior ''' q_len = input_ques_length.data - 1 c_len = input_cap_length.data.view(-1) - 1 ans_index = torch.arange(0, n_options * batch_size).long().cuda() ans_len = input_ans_length.data.view(-1) - 1 ans_seq = a_seq[ans_index, ans_len, :] ans_seq = ans_seq.view(batch_size, n_options, self.hidden_ans_dim) batch_index = torch.arange(0, batch_size).long().cuda() q_prior = torch.zeros(batch_size, q_seq.size(1)).cuda() q_prior[batch_index, q_len] = 100 c_prior = torch.zeros(batch_size, cap_seq.size(1)).cuda() c_prior[batch_index, c_len] = 100 ans_prior = torch.ones(batch_size, ans_seq.size(1)).cuda() img_prior = torch.ones(batch_size, i_feat.size(1)).cuda() (ans_atten, ques_atten, cap_atten, img_atten, hq_atten, ha_atten) = \ self.mul_atten([ans_seq, q_seq, cap_seq, i_feat, hq_seq, ha_seq], priors=[ans_prior, q_prior, c_prior, img_prior, None, None]) ''' expand to answers based ''' ques_atten = torch.unsqueeze(ques_atten, 1).expand(batch_size, n_options, self.hidden_ques_dim) cap_atten = torch.unsqueeze(cap_atten, 1).expand(batch_size, n_options, self.hidden_cap_dim) img_atten = torch.unsqueeze(img_atten, 1).expand(batch_size, n_options, self.hidden_img_dim) ans_atten = torch.unsqueeze(ans_atten, 1).expand(batch_size, n_options, self.hidden_ans_dim) ''' combine history ''' input_qahistnet = torch.cat((hq_atten, ha_atten), 1) # input_qahistnet: (nqa_per_dial*batch x 2*hidden_hist_dim) output_qahistnet = self.qahistnet(input_qahistnet) # output_qahistnet: (nqa_per_dial*batch x hidden_hist_dim) output_qahistnet = output_qahistnet.view(batch_size, nqa_per_dial * self.hidden_hist_dim) # output_qahistnet: (batch x nqa_per_dial*hidden_hist_dim) output_qahistnet = torch.unsqueeze(output_qahistnet, 1)\ .expand(batch_size, n_options, nqa_per_dial * self.hidden_hist_dim) input_qa = torch.cat((ans_seq, ques_atten, ans_atten, img_atten, output_qahistnet, cap_atten), 2) # Concatenate last dimension input_qa = input_qa.view(batch_size * n_options, self.concat_dim) out_scores = self.simnet(input_qa) out_scores = out_scores.squeeze(dim=1) out_scores = out_scores.view(batch_size, n_options) return out_scores