import torch
import torch.nn as nn

from .embedding import BERTEmbedding
from .transformer import TransformerBlock


class BERT(nn.Module):
    """
    BERT model: Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: total number of tokens in the vocabulary
        :param hidden: BERT model hidden size
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # The feed-forward sublayer of each Transformer block is 4x the hidden size,
        # as in the original BERT configuration.
        self.feed_forward_hidden = hidden * 4

        # Input embedding; applied to the token ids together with the segment labels in forward().
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # Stack of n_layers Transformer encoder blocks.
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # Build the attention mask from the padding positions (token id 0 is padding):
        # position (i, j) may attend only when both token i and token j are real tokens.
        # Shape: (batch_size, 1, seq_len, seq_len), broadcast across the attention heads.
        pad_mask = (x > 0)
        mask = (pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)).unsqueeze(1)

        # Embed the indexed token sequence into a sequence of vectors.
        x = self.embedding(x, segment_info)

        # Run the embedded sequence through every Transformer block.
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        return x
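

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition): builds the model and
# runs a forward pass on random token ids. The vocabulary size, batch size, and
# sequence length below are illustrative assumptions, as is the choice of
# segment label 1; the only convention taken from the code above is that token
# id 0 marks padding. Requires BERTEmbedding and TransformerBlock from this
# package to be importable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, batch_size, seq_len = 30000, 2, 16

    model = BERT(vocab_size=vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1)

    # Random token ids in [1, vocab_size); 0 is reserved for padding.
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
    # Segment labels (e.g. 1 for sentence A, 2 for sentence B).
    segments = torch.ones_like(tokens)

    out = model(tokens, segments)
    print(out.shape)  # expected: torch.Size([2, 16, 768])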