import torch
import torch.nn.functional as functional

from modules.inference.beam_search import BeamSearch
from utils.data import generate_language_token
import modules.constants as const

def generate_subsequent_mask(sz, device):
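    """Lower-triangular (subsequent) mask of shape [1 x sz x sz]: position i may only attend to positions j <= i."""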
    return torch.triu(
        torch.ones(sz, sz, dtype=torch.int, device=device)
    ).transpose(0, 1).unsqueeze(0)

class BeamSearch2(BeamSearch):
    """
    Same with BeamSearch2 class.
    Difference: remove the sentence which its beams terminated (reached <eos> token) from the time step loop.
    Update to reuse functions already coded in normal BeamSearch. Note that replacing unknown words & n_best is not available.
    """
    def _convert_to_sent(self, sent_id, eos_token_id):
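        """Strip the leading <sos>, truncate at the first <eos> (if any), and map token ids back to strings via the target vocab."""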
        eos = torch.nonzero(sent_id == eos_token_id).view(-1)
        t = eos[0] if len(eos) > 0 else len(sent_id)
        return [self.TRG.vocab.itos[j] for j in sent_id[1 : t]]

    @torch.no_grad()
    def beam_search(self, src, src_lang=None, trg_lang=None, src_tokens=None, n_best=1, debug=False):
        """
        Beam search select k words with the highest conditional probability
         to be the first word of the k candidate output sequences.
        Args:
            src: The batch of sentences, already in [batch_size, tokens] of int
            src_tokens: src in str version, same size as above
            n_best: number of usable values per beam loaded (Not implemented)
            debug: if true, print some debug information during the search
        Return: 
            An array of translated sentences, in list-of-tokens format. TODO convert [batch_size, n_best, tgt_len] instead of [batch_size, tgt_len]
        """
        # Create some local variables
        src_field, trg_field = self.SRC, self.TRG
        sos_token = generate_language_token(trg_lang) if trg_lang is not None else const.DEFAULT_SOS
        init_token = trg_field.vocab.stoi[sos_token]
        eos_token_id = trg_field.vocab.stoi[const.DEFAULT_EOS]
        src = src.to(self.device)
        
        batch_size = src.size(0)
        model = self.model
        k = self.beam_size
        max_len = self.max_len
        device = self.device

        # Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
        trg = torch.zeros(batch_size, k, max_len, device=device).long()
        trg[:, :, 0] = init_token

        # Precalc output from model's encoder 
        single_src_mask = (src != src_field.vocab.stoi['<pad>']).unsqueeze(1).to(device)
        e_out = model.encoder(src, single_src_mask) # [batch_size x S x d_model]

        # First decoding step: model probabilities over the vocab
        trg_mask = generate_subsequent_mask(1, device=device)
        # [batch_size x 1]
        inp_decoder = trg[:, 0, 0].view(batch_size, 1)
        # [batch_size x 1 x vocab_size]
        prob = model.out(model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
        prob = functional.softmax(prob, dim=-1)
        
        # [batch_size x 1 x k]
        k_prob, k_index = torch.topk(prob, k, dim=-1)
        trg[:, :, 1] = k_index.view(batch_size, k)
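        # Scores are accumulated in log space so per-step products become sums
        # and long sequences do not underflow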
        # Init log scores from k beams [batch_size x k x 1]
        log_scores = torch.log(k_prob.view(batch_size, k, 1))
        
        # Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
        e_outs = torch.repeat_interleave(e_out, k, dim=0)
        src_mask = torch.repeat_interleave(single_src_mask, k, dim=0)

        # A row of k <eos> ids, compared against each sentence's k beams to detect full termination
        sent_eos = torch.tensor([eos_token_id for _ in range(k)], device=device).view(1, k)

        # Indexes of the sentences still being decoded (all sentences at the start)
        batch_index = torch.arange(batch_size)
        finished_batches = torch.zeros(batch_size, device=device).long()

        # Iteratively searching
        for i in range(2, max_len):
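            # Positions 0 (<sos>) and 1 (filled by the first step above) are
            # already decoded, so the loop starts at timestep 2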
            trg_mask = generate_subsequent_mask(i, device)
            
            # Flatten trg tensor for feeding into model [(k * batch_size) x i]
            inp_decoder = trg[batch_index, :, :i].view(k * len(batch_index), i)
            # Output model prob [(k * batch_size) x i x vocab_size]
            prob = model.out(model.decoder(inp_decoder, e_outs, src_mask, trg_mask))
            prob = functional.softmax(prob, dim=-1)

            # Only the distribution at the last position (index i-1) is needed
            # [(k * batch_size) x 1 x vocab_size]
            prob = prob[:, i-1, :].view(k * len(batch_index), 1, -1)

            # Truncate prob to top k [(k * batch_size) x 1 x k]
            k_prob, k_index = prob.topk(k, dim=-1)

            # Deflatten k_prob & k_index
            k_prob = k_prob.view(len(batch_index), k, 1, k)
            k_index = k_index.view(len(batch_index), k, 1, k)

            # Preserve finished (eos) beams: fill their candidate prob with 1.0
            # (log 1.0 = 0, so their accumulated scores stay unchanged) and force
            # the candidate token back to <eos> so they keep emitting <eos>
            # [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
            eos_mask = (trg[batch_index, :, i-1] == eos_token_id).view(len(batch_index), k, 1, 1)
            k_prob.masked_fill_(eos_mask, 1.0)
            k_index.masked_fill_(eos_mask, eos_token_id)

            # Find the best k cases
            # Compute log score at i-th timestep 
            # [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
            combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob) 
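            # Flatten the k beams x k candidates into k*k hypotheses per sentence
            # and keep the k best overall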
            # [batch_size x k x 1]
            log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), k * k, 1), k, dim=1)

            # The source beam each new hypothesis came from
            rows = positions.view(len(batch_index), k) // k
            # The index within that beam's top-k candidates
            cols = positions.view(len(batch_index), k) % k
            
            batch_sim = torch.arange(len(batch_index)).view(-1, 1)
            # Reorder the kept beams to follow their parent rows, then append
            # each beam's newly selected token at timestep i
            trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
            trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), k)
            
            # Update which sentences have finished all their beams
            mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(device)
            finished_batches.masked_fill_(mask, value=1)
            cnt = torch.sum(finished_batches).item()
            if cnt == batch_size:
                break
            
            # Continue with remaining batches (if any)
            batch_index = torch.nonzero(finished_batches == 0).view(-1)
            e_outs = torch.repeat_interleave(e_out[batch_index], k, dim=0)
            src_mask = torch.repeat_interleave(single_src_mask[batch_index], k, dim=0)
            # End loop

        # Get the best beam
        log_scores = log_scores.view(batch_size, k)
        results = [self._convert_to_sent(trg[t, j.item(), :], eos_token_id) for t, j in enumerate(torch.argmax(log_scores, dim=-1))]
        return results
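
# Usage sketch (illustrative only): how this class might be driven end to end.
# The constructor arguments below are assumptions for illustration -- the real
# signature is defined on BeamSearch in modules.inference.beam_search and may
# differ.
#
#   searcher = BeamSearch2(model, max_len=160, device=device, beam_size=5)
#   # src: [1 x src_len] tensor of source token ids from the SRC vocab
#   src = torch.tensor([[SRC.vocab.stoi[w] for w in "hello world".split()]])
#   sentences = searcher.beam_search(src, trg_lang="en")
#   print(" ".join(sentences[0]))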