##@title My beam search
import torch
import torch.nn.functional as functional

from modules.inference.decode_strategy import DecodeStrategy
from utils.misc import no_peeking_mask

class GreedySearch(DecodeStrategy):
    def __init__(self, model, max_len, device, batch_size, replace_unk=True):
        """
        :param model: the trained translation model
        :param max_len: maximum length of a decoded sentence
        :param device: torch device to run on
        :param batch_size: number of sentences decoded per batch
        :param replace_unk: if True, replace <unk> tokens using attention
        """
        super(GreedySearch, self).__init__(model, max_len, device)
        self.batch_size = batch_size
        self.replace_unk = replace_unk
        raise NotImplementedError("replace_unk support was removed from the base class DecodeStrategy. Fix that first.")

    def initialize_value(self, sentences):
        """
        Set up the tensors needed before decoding starts.
        Input:
        :param sentences: batch of tokenized source sentences

        Output: eos token id, source mask, finished-sentence marker,
        encoder output, and the output buffer seeded with <sos>
        """

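        # Special-token ids and the padding mask over the source batch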
        init_tok = self.TRG.vocab.stoi['<sos>']
        src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        eos_tok = self.TRG.vocab.stoi['<eos>']

        # Run the encoder once; its output is reused at every decoding step
        e_output = self.model.encoder(sentences, src_mask)

        # Output buffer: one row per sentence, seeded with <sos> and filled in-place
        outputs = torch.zeros(self.batch_size, self.max_len, dtype=torch.long, device=self.device)
        outputs[:, 0] = init_tok

        # A vector of <eos> ids, compared against the current column to detect completion
        is_finished = torch.full((self.batch_size,), eos_tok, dtype=torch.long, device=self.device)
        return eos_tok, src_mask, is_finished, e_output, outputs

    def create_trg_mask(self, i, device):
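        # Causal mask so position t may only attend to positions <= t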
        return no_peeking_mask(i, device)

    def current_predict(self, outputs, e_output, src_mask, trg_mask):
        # Return log-probabilities (not probabilities): sample_with_temperature
        # treats its input as logits, so temperature scaling and the -10000
        # top-k mask only behave correctly in log space.
        out, attn = self.model.out(self.model.decoder(outputs,
                                  e_output, src_mask, trg_mask))
        out = functional.log_softmax(out, dim=-1)
        return out, attn

    def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
        """Decode token by token, feeding each prediction back into the decoder."""
        # Shrink the batch size for a final, smaller batch
        if len(sentences) < self.batch_size:
            self.batch_size = len(sentences)

        eos_tok, src_mask, is_finished, e_output, outputs = self.initialize_value(sentences)

        for i in range(1, self.max_len):
            out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
            topk_ix, topk_prob = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
            outputs[:, i] = topk_ix.view(-1)
            # Stop early only when every sentence emits <eos> at the same step
            if torch.equal(outputs[:, i], is_finished):
                break
        
        if self.replace_unk:
            outputs = self.replace_unknown(attn, sentences, outputs)

        # Convert token ids back to tokens, trimming <sos> and anything after <eos>
        translated_sentences = []
        for i in range(len(outputs)):
            eos_positions = torch.nonzero(outputs[i] == eos_tok)
            if len(eos_positions) == 0:
                # No <eos> was generated: keep everything after <sos>
                sent = outputs[i, 1:]
            else:
                length = eos_positions[0]
                sent = outputs[i, 1:length]
            translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])

        return translated_sentences

    def sample_with_temperature(self, logits, sampling_temp, keep_topk):
        """Pick the next token per sentence: argmax when greedy, otherwise a
        temperature-scaled sample restricted to the top-k candidates."""
        if sampling_temp == 0.0 or keep_topk == 1:
            # For temp=0.0, take the argmax to avoid divide-by-zero errors.
            # keep_topk=1 is also equivalent to argmax.
            topk_scores, topk_ids = logits.topk(1, dim=-1)
            if sampling_temp > 0:
                topk_scores /= sampling_temp
        else:
            # temp < 1 sharpens the distribution, temp > 1 flattens it
            logits = torch.div(logits, sampling_temp)

            if keep_topk > 0:
                top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
                kth_best = top_values[:, -1].view([-1, 1])
                kth_best = kth_best.repeat([1, logits.shape[1]]).float()

                # Set all logits that are not in the top-k to -10000.
                # This puts the probabilities close to 0.
                ignore = torch.lt(logits, kth_best)
                logits = logits.masked_fill(ignore, -10000)

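            # A single-trial Multinomial draw is a categorical sample; argmax
            # over the one-hot result recovers the sampled token id.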
            dist = torch.distributions.Multinomial(
                logits=logits, total_count=1)
            topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
            topk_scores = logits.gather(dim=1, index=topk_ids)
        return topk_ids, topk_scores

    def translate_batch(self, sentences):
        sentences = self.preprocess_batch(sentences).to(self.device)
        # Hardcoded sampling_temp=0.2 and keep_topk=2: mildly stochastic
        # decoding rather than pure argmax greedy search.
        return self.greedy_search(sentences, 0.2, 2)
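
# A minimal, runnable sketch of the top-k temperature sampling performed by
# sample_with_temperature above, on dummy log-probabilities. No model or data
# is needed; the shapes, seed, and values here are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    # 4 "sentences", vocabulary of 10 tokens
    log_probs = functional.log_softmax(torch.randn(4, 10), dim=-1)

    temp, k = 0.2, 2
    scaled = log_probs / temp
    # Mask everything below the k-th best score, mirroring the method above
    kth_best = torch.topk(scaled, k, dim=1).values[:, -1].view(-1, 1)
    scaled = scaled.masked_fill(scaled < kth_best, -10000)
    one_hot = torch.distributions.Multinomial(logits=scaled, total_count=1).sample()
    print(torch.argmax(one_hot, dim=1))  # one sampled token id per sentence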