# modules/inference/sampling_temperature.py
##@title My beam search
import torch
import torch.nn.functional as functional

from modules.inference.decode_strategy import DecodeStrategy
from utils.misc import no_peeking_mask


class GreedySearch(DecodeStrategy):
    def __init__(self, model, max_len, device, batch_size, replace_unk=False):
        """
        :param model: trained NMT model used for decoding
        :param max_len: maximum length of a generated translation
        :param device: torch device decoding runs on
        :param batch_size: number of sentences decoded per batch
        :param replace_unk: replace <unk> tokens using attention (currently unsupported)
        """
        super(GreedySearch, self).__init__(model, max_len, device)
        self.batch_size = batch_size
        self.replace_unk = replace_unk
        if self.replace_unk:
            # replace_unknown was removed from the base class DecodeStrategy;
            # re-enable this option only after restoring that method there.
            raise NotImplementedError("replace_unk support was removed from base class DecodeStrategy. Fix first.")

    def initialize_value(self, sentences):
        """
        Prepare the tensors needed before decoding starts.
        Input:
            :param sentences: batch of source sentences, shape (batch_size, src_len)
        Output: the <eos> token id, the source mask, an <eos> comparison vector,
            the encoder output, and the output buffer initialized with <sos>
        """
        init_tok = self.TRG.vocab.stoi['<sos>']
        src_mask = (sentences != self.SRC.vocab.stoi['<pad>']).unsqueeze(-2)
        eos_tok = self.TRG.vocab.stoi['<eos>']

        # Encode the source batch once; every decoding step reuses e_output.
        e_output = self.model.encoder(sentences, src_mask)

        # Output buffer: each row starts with <sos> and is filled step by step.
        out = torch.LongTensor([[init_tok] for _ in range(self.batch_size)])
        outputs = torch.zeros(self.batch_size, self.max_len).long()
        outputs[:, :1] = out
        outputs = outputs.to(self.device)

        # One <eos> id per sentence, used to detect when the whole batch has finished.
        is_finished = torch.LongTensor([[eos_tok] for _ in range(self.batch_size)]).view(-1).to(self.device)
        return eos_tok, src_mask, is_finished, e_output, outputs

    def create_trg_mask(self, i, device):
        # Causal ("no peeking") mask: position t may only attend to positions <= t.
        return no_peeking_mask(i, device)
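
    # A sketch of the intended mask (assuming no_peeking_mask returns a
    # lower-triangular subsequent-position mask of shape (1, i, i), which is
    # what its name and its use as trg_mask suggest), e.g. for i = 3:
    #
    #   [[[1, 0, 0],
    #     [1, 1, 0],
    #     [1, 1, 1]]]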

    def current_predict(self, outputs, e_output, src_mask, trg_mask):
        # Decode the tokens generated so far, project to the target vocabulary,
        # and softmax the logits into per-token probabilities.
        out, attn = self.model.out(self.model.decoder(outputs, e_output, src_mask, trg_mask))
        out = functional.softmax(out, dim=-1)
        return out, attn

    def greedy_search(self, sentences, sampling_temp=0.0, keep_topk=1):
        """
        Decode a batch of source sentences token by token, sampling each next
        token with the given temperature and top-k cutoff.
        """
        # Shrink the batch size if fewer sentences than expected were given.
        if len(sentences) < self.batch_size:
            self.batch_size = len(sentences)
        eos_tok, src_mask, is_finished, e_output, outputs = self.initialize_value(sentences)
        for i in range(1, self.max_len):
            out, attn = self.current_predict(outputs[:, :i], e_output, src_mask, self.create_trg_mask(i, self.device))
            topk_ix, topk_scores = self.sample_with_temperature(out[:, -1], sampling_temp, keep_topk)
            outputs[:, i] = topk_ix.view(-1)
            # Stop early only if every sentence emits <eos> at the same step.
            if torch.equal(outputs[:, i], is_finished):
                break
        if self.replace_unk:
            outputs = self.replace_unknown(attn, sentences, outputs)

        # Convert token ids back to target-language strings.
        translated_sentences = []
        # Cut each output at its first <eos>, dropping the leading <sos>.
        for i in range(len(outputs)):
            is_eos = torch.nonzero(outputs[i] == eos_tok)
            if len(is_eos) == 0:
                # No <eos> was generated: keep the whole sequence after <sos>.
                sent = outputs[i, 1:]
            else:
                length = is_eos[0]
                sent = outputs[i, 1:length]
            translated_sentences.append([self.TRG.vocab.itos[tok] for tok in sent])
        return translated_sentences
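
    # Decoding sketch for one sentence with max_len = 5 (illustrative ids only):
    #   step 0: outputs = [<sos>,   0,     0,     0,    0]
    #   step 1: outputs = [<sos>, 'I',     0,     0,    0]
    #   step 2: outputs = [<sos>, 'I', 'eat',     0,    0]
    #   step 3: outputs = [<sos>, 'I', 'eat', <eos>,    0]  -> cut at first <eos>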

    def sample_with_temperature(self, logits, sampling_temp, keep_topk):
        """
        Select one token id per sentence from `logits`.
        :param logits: scores over the target vocabulary, shape (batch_size, vocab_size)
        :param sampling_temp: softmax temperature; 0.0 means plain argmax
        :param keep_topk: sample only among the k most likely tokens; <= 0 keeps all
        :return: (chosen token ids, their scores), each of shape (batch_size, 1)
        """
        if sampling_temp == 0.0 or keep_topk == 1:
            # For temp=0.0, take the argmax to avoid divide-by-zero errors.
            # keep_topk=1 is also equivalent to argmax.
            topk_scores, topk_ids = logits.topk(1, dim=-1)
            if sampling_temp > 0:
                topk_scores /= sampling_temp
        else:
            logits = torch.div(logits, sampling_temp)
            if keep_topk > 0:
                # Find each row's k-th best score and mask out everything below it.
                top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
                kth_best = top_values[:, -1].view([-1, 1])
                kth_best = kth_best.repeat([1, logits.shape[1]]).float()
                # Set all logits that are not in the top-k to -10000.
                # This puts their probabilities close to 0.
                ignore = torch.lt(logits, kth_best)
                logits = logits.masked_fill(ignore, -10000)
            # Draw one sample per row from the temperature-scaled, truncated distribution.
            dist = torch.distributions.Multinomial(logits=logits, total_count=1)
            topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
            topk_scores = logits.gather(dim=1, index=topk_ids)
        return topk_ids, topk_scores
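
    # A minimal smoke test of the sampling rule above (a sketch; the method never
    # reads `self`, so it can be exercised without building a full model):
    #
    #   >>> logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    #   >>> GreedySearch.sample_with_temperature(None, logits, 0.0, 1)
    #   (tensor([[0]]), tensor([[2.]]))     # temp 0.0 -> plain argmax
    #   >>> ids, _ = GreedySearch.sample_with_temperature(None, logits, 0.5, 2)
    #   >>> int(ids) in (0, 1)              # only the top-2 tokens can be drawn
    #   True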

    def translate_batch(self, sentences):
        sentences = self.preprocess_batch(sentences).to(self.device)
        # Hard-coded decoding hyperparameters: temperature 0.2, sample from the top 2.
        return self.greedy_search(sentences, sampling_temp=0.2, keep_topk=2)
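

if __name__ == "__main__":
    # Self-contained smoke test (an illustrative sketch, not part of the original
    # module): it exercises only the sampling rule and needs no trained model,
    # since sample_with_temperature never reads `self`.
    torch.manual_seed(0)
    fake_probs = functional.softmax(torch.randn(4, 10), dim=-1)  # pretend vocab of 10
    ids, scores = GreedySearch.sample_with_temperature(None, fake_probs, sampling_temp=0.2, keep_topk=2)
    print("sampled ids:", ids.view(-1).tolist())
    print("their scores:", scores.view(-1).tolist())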