import torch
import torch.nn as nn
import math



# Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the input embeddings.

    Expects input of shape (seq_len, batch_size, d_model), the default
    non-batch-first layout used by nn.TransformerEncoder.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe[pos, 0, 2i]   = sin(pos / 10000^(2i / d_model))
        # pe[pos, 0, 2i+1] = cos(pos / 10000^(2i / d_model))
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Broadcast the (seq_len, 1, d_model) encoding over the batch dimension.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


"""
    Same scheduler as in "Attention Is All You Need"
"""
class NoamScheduler():
    def __init__(self, optimizer, warmup, model_size):
        self.epoch = 0
        self.optimizer = optimizer
        self.warmup = warmup
        self.model_size = model_size

    def step(self):
        self.epoch += 1
        new_lr = self.model_size**(-0.5) * min(self.epoch**(-0.5), self.epoch * self.warmup**(-1.5))

        for param in self.optimizer.param_groups:
            param["lr"] = new_lr


"""
    Encoders to attend sentence level features.
"""
class TransformerInterEncoder(nn.Module):
    def __init__(self, d_model, d_ff=2048, nheads=8, num_encoders=2, dropout=0.1, max_len=512):
        super().__init__()
        self.positional_enc = PositionalEncoding(d_model, dropout, max_len)
        self.encoders = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=d_ff), 
            num_layers=num_encoders
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.linear = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.positional_enc(x)
        x = self.encoders(x)
        x = self.layer_norm(x)
        logit = self.linear(x)
        sentences_scores = self.sigmoid(logit)

        return sentences_scores.squeeze(-1), logit.squeeze(-1)
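

# Minimal smoke test (not part of the original module). The shapes and
# hyper-parameters below are illustrative assumptions, chosen only to show how
# the classes fit together.
if __name__ == "__main__":
    d_model, num_sentences, batch_size = 128, 10, 4

    model = TransformerInterEncoder(d_model, d_ff=512, nheads=4, num_encoders=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0)
    scheduler = NoamScheduler(optimizer, warmup=4000, model_size=d_model)

    # Sentence-level features in (num_sentences, batch_size, d_model) layout.
    x = torch.randn(num_sentences, batch_size, d_model)
    scores, logits = model(x)
    print(scores.shape, logits.shape)  # both: (num_sentences, batch_size)

    scheduler.step()  # one scheduler step per optimizer update
    print(optimizer.param_groups[0]["lr"])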