import torch
import torch.nn as nn

from .transformer import TransformerBlock
from .embedding import BERTEmbedding


class BERT(nn.Module):
    """
    BERT model: Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: size of the token vocabulary
        :param hidden: BERT model hidden size
        :param n_layers: number of Transformer blocks (layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # the paper uses 4 * hidden_size for the feed-forward hidden size
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT: sum of token, positional, and segment embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # stack of transformer encoder blocks (deep network)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, self.feed_forward_hidden, dropout)
             for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention mask for padded tokens (token id 0 is padding):
        # entry (i, j) is True only if both the query token i and the key token j
        # are non-padding. Shape: torch.BoolTensor (batch_size, 1, seq_len, seq_len).
        pad = (x > 0)  # (batch_size, seq_len)
        mask = (pad.unsqueeze(1) & pad.unsqueeze(2)).unsqueeze(1)

        # embed the indexed token sequence into a sequence of vectors
        x = self.embedding(x, segment_info)

        # run through the stacked transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        return x
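

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# TransformerBlock and BERTEmbedding follow the signatures used above, and
# that token id 0 is padding, matching the `x > 0` mask. Because this file
# uses relative imports, run it as a module from the package root, e.g.
# `python -m <package>.bert` (the package path is an assumption).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size = 100
    model = BERT(vocab_size=vocab_size, hidden=64, n_layers=2, attn_heads=4)

    batch_size, seq_len = 2, 16
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))  # non-pad ids
    tokens[:, -4:] = 0                                            # pad the tail
    segments = torch.ones_like(tokens)                            # single segment

    out = model(tokens, segments)
    print(out.shape)  # expected: (batch_size, seq_len, hidden) -> (2, 16, 64)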