import torch
import torch.nn as nn

from .transformer import TransformerBlock
from .embedding import BERTEmbedding


class BERT(nn.Module):
    """
    BERT model: Bidirectional Encoder Representations from Transformers.
    """
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: size of the token vocabulary
        :param hidden: hidden size of the BERT model
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # the paper uses 4 * hidden_size for the feed-forward hidden size
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT: sum of token, segment, and positional embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # stack of Transformer encoder blocks
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, self.feed_forward_hidden, dropout)
             for _ in range(n_layers)])
    def forward(self, x, segment_info):
        # attention mask for padded tokens: an entry is True only when both the
        # query and the key position hold real (non-pad, i.e. id > 0) tokens.
        # shape: torch.BoolTensor of [batch_size, 1, seq_len, seq_len]
        pad_mask = (x > 0)  # [batch_size, seq_len]
        mask = (pad_mask.unsqueeze(2) & pad_mask.unsqueeze(1)).unsqueeze(1)

        # embed the indexed token sequence into a sequence of vectors
        x = self.embedding(x, segment_info)

        # run the sequence through the stack of transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        return x
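

# --- Minimal usage sketch (illustrative only, not part of the model) ---
# This assumes TransformerBlock and BERTEmbedding behave as imported above and
# that token id 0 is reserved for padding, which is what the mask construction
# in forward() relies on. Run it as a module (e.g. `python -m <package>.bert`)
# so the relative imports resolve; the hyperparameters below are arbitrary
# small values chosen for a quick shape check.
if __name__ == "__main__":
    vocab_size = 100
    model = BERT(vocab_size=vocab_size, hidden=64, n_layers=2, attn_heads=4)

    # a batch of 2 sequences of length 8; zeros at the tail act as padding
    tokens = torch.randint(1, vocab_size, (2, 8))
    tokens[:, 6:] = 0
    # segment label 1 for real tokens, 0 for padding (assumed convention)
    segments = (tokens > 0).long()

    out = model(tokens, segments)
    print(out.shape)  # expected: torch.Size([2, 8, 64])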