import torch
import torch.nn as nn

from .transformer import TransformerBlock
from .embedding import BERTEmbedding


class BERT(nn.Module):
    """
    BERT model: Bidirectional Encoder Representations from Transformers.
    """
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: size of the token vocabulary
        :param hidden: BERT hidden size (the defaults match BERT-base)
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # the paper uses 4 * hidden_size as the feed-forward hidden size
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT: sum of token, positional and segment embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # stack of Transformer encoder blocks
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
    def forward(self, x, segment_info):
        # attention mask for padded tokens:
        # torch.BoolTensor of shape [batch_size, 1, seq_len, seq_len], where
        # mask[b, 0, i, j] is True only if positions i and j are both non-padding
        pad = (x > 0)
        mask = (pad.unsqueeze(1) & pad.unsqueeze(2)).unsqueeze(1)

        # embed the indexed token sequence into a sequence of vectors
        x = self.embedding(x, segment_info)

        # run the sequence through the stack of transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        return x
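

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build a toy batch
    # and run one forward pass. It assumes padding uses token id 0 (as the mask
    # above does), that the relative imports resolve because the file is run as
    # part of its package (e.g. `python -m astra.src.bert`), and that the
    # TransformerBlock keeps the (batch, seq_len, hidden) shape. All sizes below
    # are arbitrary.
    vocab_size = 1000
    model = BERT(vocab_size=vocab_size, hidden=256, n_layers=2, attn_heads=4)

    batch_size, seq_len = 2, 16
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
    tokens[:, -4:] = 0                    # pretend the last four positions are padding
    segments = torch.ones_like(tokens)    # single-segment input; label convention depends on BERTEmbedding

    out = model(tokens, segments)
    print(out.shape)                      # expected: torch.Size([batch_size, seq_len, hidden])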