import math

import torch
import torch.nn as nn


class GELU(nn.Module):
    """
    Gaussian Error Linear Unit, computed with the tanh approximation.

    Section 3.4 (last paragraph) of the BERT paper notes that BERT uses GELU
    rather than ReLU as its activation function.
    """

    def forward(self, x):
        # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
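
# Note: recent PyTorch releases also provide this activation as a built-in,
# torch.nn.functional.gelu(x, approximate="tanh"); the explicit module above
# keeps the formula visible.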


class PositionwiseFeedForward(nn.Module):
    "Implements the position-wise feed-forward network: w_2(dropout(GELU(w_1(x))))."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
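
# For scale reference, BERT-base uses d_model = 768 and d_ff = 3072 (i.e. 4 * d_model).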


class LayerNorm(nn.Module):
    "Construct a layer normalization module (see citation for details)."

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # learnable gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # learnable bias
        self.eps = eps

    def forward(self, x):
        # Normalize over the last (feature) dimension.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
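
# Note: this differs slightly from torch.nn.LayerNorm, which adds eps to the
# variance rather than the standard deviation and uses the biased variance
# estimate; x.std(-1) above applies Bessel's correction. The numerical
# difference is negligible for typical feature sizes.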


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note that for code simplicity the norm is applied first (pre-norm) rather than last.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply a residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))
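

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# illustrative sizes (batch=2, seq_len=4, d_model=8, d_ff=32) and checks that
# each block preserves shapes and that the tanh-based GELU stays close to the
# exact erf-based GELU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 4, 8)  # (batch, seq_len, d_model)

    # The tanh approximation should track the exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))).
    exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
    assert torch.allclose(GELU()(x), exact, atol=1e-3)

    ffn = PositionwiseFeedForward(d_model=8, d_ff=32, dropout=0.1)
    assert ffn(x).shape == x.shape

    norm = LayerNorm(8)
    assert norm(x).shape == x.shape

    # Pre-norm residual connection wrapping the feed-forward block.
    sublayer = SublayerConnection(size=8, dropout=0.1)
    assert sublayer(x, ffn).shape == x.shape

    print("All shape and GELU-approximation checks passed.")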