import torch
import torch.nn as nn
import torch.nn.functional as functional
import math
import logging

class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_length=200, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self._max_seq_length = max_seq_length
        
        pe = torch.zeros(max_seq_length, d_model)
        
        # standard sinusoidal encoding: PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        #                               PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        # (`i` already steps over the even indices here, so the exponent is i / d_model)
        for pos in range(max_seq_length):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
        pe = pe.unsqueeze(0)        
        self.register_buffer('pe', pe)

        @torch.jit.script
        def splice_by_size(source, target):
            """Slice `source` along dim 1 to match `target`'s second dimension.
            Wrapped as a scripted helper because torch.Size is not a torch.Tensor."""
            length = target.size(1)
            return source[:, :length]

        self.splice_by_size = splice_by_size
    
    def forward(self, x):
        if x.shape[1] > self._max_seq_length:
            logging.warning(
                "Input is longer than the maximum sequence length supported by this PositionalEncoder; "
                "it will be trimmed. Build the model with a larger max_seq_length to keep the full input."
            )
            x = x[:, :self._max_seq_length]
        
        x = x * math.sqrt(self.d_model)
        
        # add the positional encoding for the first seq_len positions
        spliced_pe = self.splice_by_size(self.pe, x)  # equivalent to self.pe[:, :x.shape[1]]
        pe = spliced_pe.requires_grad_(False)

        x = x + pe
        x = self.dropout(x)
        
        return x
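
# A minimal usage sketch (illustrative only, not part of the original module): the
# encoder expects embeddings of shape [batch_size, seq_len, d_model] and returns a
# tensor of the same shape with positional information added.
def _demo_positional_encoder():
    d_model = 512
    pos_enc = PositionalEncoder(d_model, max_seq_length=200, dropout=0.1)
    embeddings = torch.zeros(2, 10, d_model)   # hypothetical batch of token embeddings
    out = pos_enc(embeddings)                  # -> [2, 10, 512]
    return out.shape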

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        # three projection layers for query, key and value
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        """
        Args:
            q / k / v: query/key/value, should all be [batch_size, sequence_length, d_model]. Only differ in decode attention, where q is tgt_len and k/v is src_len
            mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
        Returns:
            the value of the attention process, [batch_size, sequence_length, d_model].
            The used attention, [batch_size, q_length, k_v_length]
        """
        bs = q.shape[0]
        # project, then split the last dimension into heads: [bs, seq_len, d_model] -> [bs, seq_len, h, d_k]
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # move the head dimension forward: [bs, h, seq_len, d_k]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        value, attn = self.attention(q, k, v, mask, self.dropout)
        # concatenate the heads back together and apply the output projection
        concat = value.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output, attn

    def attention(self, q, k, v, mask=None, dropout=None):
        """Calculate the attention and output the attention & value
        Args:
            q / k / v: query/key/value already transformed, should all be [batch_size, heads, sequence_length, d_k]. Only differ in decode attention, where q is tgt_len and k/v is src_len
            mask: either [batch_size, 1, src_len] or [batch_size, tgt_len, tgt_len]. The last two dimensions must match or are broadcastable.
        Returns: 
            the attentionized but raw values [batch_size, head, seq_length, d_k]
            the attention calculated [batch_size, heads, sequence_length, sequence_length]
        """
    
        # scaled dot-product scores: [bs, h, q_len, d_k] x [bs, h, d_k, k_len] -> [bs, h, q_len, k_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head dimension so the mask broadcasts over heads
            scores = scores.masked_fill(mask == 0, -1e9)
        # softmax over the key dimension after applying the padding/peeking mask
        scores = functional.softmax(scores, dim=-1)
        
        if dropout is not None:
            scores = dropout(scores)
        
        output = torch.matmul(scores, v)
        return output, scores
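
# A minimal self-attention sketch (illustrative only): a batch of two sequences of
# length 7 where the last two positions of the second sequence are padding.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(heads=8, d_model=512)
    x = torch.randn(2, 7, 512)
    pad_mask = torch.ones(2, 1, 7, dtype=torch.bool)
    pad_mask[1, :, 5:] = False                 # hide padded key positions
    out, attn = mha(x, x, x, mask=pad_mask)    # out: [2, 7, 512], attn: [2, 8, 7, 7]
    return out.shape, attn.shape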

class Norm(nn.Module):
    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
    
        self.size = d_model
        
        # create two learnable parameters to calibrate normalisation
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        
        self.eps = eps
    
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
               / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm
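
# A sanity-check sketch (illustrative only): Norm performs layer normalisation over the
# last dimension. Note that x.std() defaults to the unbiased estimator, so its output is
# close to, but not identical with, nn.LayerNorm, which uses the biased variance.
def _demo_norm():
    norm = Norm(d_model=512)
    x = torch.randn(2, 7, 512)
    out = norm(x)                              # same shape, normalised per position
    reference = nn.LayerNorm(512)(x)           # for comparison only
    return out.shape, reference.shape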

class FeedForward(nn.Module):
    """A two-hidden-linear feedforward layer that can activate and dropout its transition state"""
    def __init__(self, d_model, d_ff=2048, internal_activation=functional.relu, dropout=0.1):
        super().__init__() 
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

        self.internal_activation = internal_activation
    
    def forward(self, x):
        x = self.dropout(self.internal_activation(self.linear_1(x)))
        x = self.linear_2(x)
        return x
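
# An illustrative composition sketch (a guess at how these modules fit together, not the
# repository's actual encoder layer): one post-norm Transformer sublayer stack of
# self-attention followed by the feed-forward block, each with a residual connection.
class _DemoEncoderLayer(nn.Module):
    def __init__(self, heads=8, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, d_ff, dropout=dropout)
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, mask=mask)      # self-attention sublayer
        x = self.norm_1(x + self.dropout(attn_out))      # residual + layer norm
        x = self.norm_2(x + self.dropout(self.ff(x)))    # feed-forward sublayer
        return x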