import torch
import time
import torch.nn.functional as functional
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence

from modules.inference.beam_search import BeamSearch
from utils.data import generate_language_token
import modules.constants as const

def generate_subsequent_mask(sz, device):
    return torch.triu(
        torch.ones(sz, sz, dtype=torch.int, device=device)
    ).transpose(0, 1).unsqueeze(0)
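
# The returned mask is lower-triangular, so position t may only attend to
# positions <= t. For example, generate_subsequent_mask(3, 'cpu') yields
# (shape [1, 3, 3]):
#   [[[1, 0, 0],
#     [1, 1, 0],
#     [1, 1, 1]]]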

class BeamSearch2(BeamSearch):
    """
    Same as the base BeamSearch class, with one difference: sentences whose
    beams have all terminated (reached the <eos> token) are removed from the
    time-step loop. Reuses the functions already implemented in BeamSearch.
    Note that unknown-word replacement and n_best > 1 are not supported.
    """

    def _convert_to_sent(self, sent_id, eos_token_id):
        eos = torch.nonzero(sent_id == eos_token_id).view(-1)
        t = eos[0] if len(eos) > 0 else len(sent_id)
        return [self.TRG.vocab.itos[j] for j in sent_id[1:t]]
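    # Example (hypothetical ids): with <sos> = 2 and <eos> = 3,
    # sent_id = tensor([2, 51, 94, 3, 1, 1]) maps to the itos strings of
    # ids [51, 94]; the leading <sos> and everything from <eos> on are dropped.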

    @torch.no_grad()
    def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, debug=False):
        """
        Beam search selects the k words with the highest conditional
        probability to be the first words of the k candidate output
        sequences, then extends the k best partial hypotheses step by step.
        Args:
            src: the batch of source sentences, a [batch_size, src_len] int tensor
            src_tokens: src as strings, same shape as above (unused here)
            n_best: number of hypotheses to return per sentence (not implemented)
            debug: if True, print debug information during the search
        Return:
            A list of translated sentences, each a list of tokens.
            TODO: return [batch_size, n_best, trg_len] instead of [batch_size, trg_len].
        """
        # Create some local variables
        src_field, trg_field = self.SRC, self.TRG
        sos_token = generate_language_token(trg_lang) if trg_lang is not None else const.DEFAULT_SOS
        init_token = trg_field.vocab.stoi[sos_token]
        eos_token_id = trg_field.vocab.stoi[const.DEFAULT_EOS]
        src = src.to(self.device)
        batch_size = src.size(0)
        model = self.model
        k = self.beam_size
        max_len = self.max_len
        device = self.device
        # Initialize the target sequences (start with the <sos> token) [batch_size x k x max_len]
        trg = torch.zeros(batch_size, k, max_len, device=device).long()
        trg[:, :, 0] = init_token
        # Precompute the output of the model's encoder
        single_src_mask = (src != src_field.vocab.stoi['<pad>']).unsqueeze(1).to(device)
        e_out = model.encoder(src, single_src_mask)  # [batch_size x S x d_model]
        # First decoding step: model probabilities for the word after <sos>
        trg_mask = generate_subsequent_mask(1, device=device)
        # [batch_size x 1]
        inp_decoder = trg[:, 0, 0].view(batch_size, 1)
        # [batch_size x 1 x vocab_size]
        prob = model.out(model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
        prob = functional.softmax(prob, dim=-1)
        # Seed the k beams with the top-k first words [batch_size x 1 x k]
        k_prob, k_index = torch.topk(prob, k, dim=-1)
        trg[:, :, 1] = k_index.view(batch_size, k)
        # Initialize the log scores of the k beams [batch_size x k x 1]
        log_scores = torch.log(k_prob.view(batch_size, k, 1))
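        # From here on, a beam's score is its cumulative log-probability,
        #   log P(y_1..y_t | x) = log P(y_1 | x) + ... + log P(y_t | y_<t, x),
        # so each step below just adds log P(y_t | ...) to the running score.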
        # Repeat the encoder's output k times for searching [(k * batch_size) x S x d_model]
        e_outs = torch.repeat_interleave(e_out, k, dim=0)
        src_mask = torch.repeat_interleave(single_src_mask, k, dim=0)
        # Reference row of <eos> ids, used to check whether all k beams of a sentence finished
        sent_eos = torch.full((1, k), eos_token_id, dtype=torch.long, device=device)
        # Indexes of the sentences still being decoded
        batch_index = torch.arange(batch_size)
        finished_batches = torch.zeros(batch_size, device=device).long()
        # Iteratively search over the remaining time steps
        for i in range(2, max_len):
            trg_mask = generate_subsequent_mask(i, device)
            # Flatten the trg tensor for feeding into the model [(k * batch_size) x i]
            inp_decoder = trg[batch_index, :, :i].view(k * len(batch_index), i)
            # Output model prob [(k * batch_size) x i x vocab_size]
            prob = model.out(model.decoder(inp_decoder, e_outs, src_mask, trg_mask))
            prob = functional.softmax(prob, dim=-1)
            # Only the distribution at the last position (i-1) matters
            # [(k * batch_size) x 1 x vocab_size]
            prob = prob[:, i-1, :].view(k * len(batch_index), 1, -1)
            # Truncate prob to top k [(k * batch_size) x 1 x k]
            k_prob, k_index = prob.data.topk(k, dim=-1)
            # Unflatten k_prob & k_index
            k_prob = k_prob.view(len(batch_index), k, 1, k)
            k_index = k_index.view(len(batch_index), k, 1, k)
            # Preserve beams that already emitted <eos>
            # [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
            eos_mask = (trg[batch_index, :, i-1] == eos_token_id).view(len(batch_index), k, 1, 1)
            k_prob.masked_fill_(eos_mask, 1.0)
            k_index.masked_fill_(eos_mask, eos_token_id)
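            # Note: log(1.0) == 0, so a finished beam's cumulative score is
            # left unchanged, and forcing the picked index to <eos> keeps the
            # beam padded with <eos> for the remaining steps.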
            # Find the best k cases
            # Compute the log scores at the i-th timestep
            # [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
            combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
            # [batch_size x k x 1]
            log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), k * k, 1), k, dim=1)
            # The beam rows the top-k entries were selected from
            rows = positions.view(len(batch_index), k) // k
            # The positions within each row's top-k candidate words
            cols = positions.view(len(batch_index), k) % k
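            # Worked example: with k = 3, the k*k candidates are flattened
            # row-major, so position 7 decodes to rows = 7 // 3 = 2 (it came
            # from beam 2) and cols = 7 % 3 = 1 (that beam's 2nd-best word).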
            batch_sim = torch.arange(len(batch_index)).view(-1, 1)
            # Reorder the surviving beams, then append the newly chosen words
            trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
            trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), k)
            # Update which sentences have finished all of their beams
            mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(device)
            finished_batches.masked_fill_(mask, value=1)
            cnt = torch.sum(finished_batches).item()
            if cnt == batch_size:
                break
            # Continue with the remaining unfinished sentences (if any)
            batch_index = torch.nonzero(finished_batches == 0).view(-1)
            e_outs = torch.repeat_interleave(e_out[batch_index], k, dim=0)
            src_mask = torch.repeat_interleave(single_src_mask[batch_index], k, dim=0)
        # End of the loop: pick the best-scoring beam for every sentence
        log_scores = log_scores.view(batch_size, k)
        results = [self._convert_to_sent(trg[t, j.item(), :], eos_token_id)
                   for t, j in enumerate(torch.argmax(log_scores, dim=-1))]
        return results
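

# Usage sketch (hypothetical setup: the BeamSearch constructor signature and
# the `model`, `SRC`, `TRG` objects below are assumptions, not part of this
# module; adapt them to however the base BeamSearch is actually constructed):
#
#   searcher = BeamSearch2(model=model, max_len=160, device='cuda',
#                          beam_size=5, SRC=SRC, TRG=TRG)  # assumed signature
#   # src: [batch_size, src_len] tensor of source token ids
#   translations = searcher.beam_search(src, trg_lang='en')
#   print(' '.join(translations[0]))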