|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements Mosaic BERT, with an eye towards the Hugging Face API. |
|
|
|
Mosaic BERT improves performance over Hugging Face BERT through the following: |
|
|
|
1. ALiBi. This architectural change removes positional embeddings and instead encodes positional |
|
information through attention biases based on query-key position distance. It improves the effectiveness |
|
of training with shorter sequence lengths by enabling extrapolation to longer sequences. |
|
|
|
2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer |
|
to improve overall expressiveness, providing better convergence properties. |
|
|
|
3. Flash Attention. Mosaic BERT's self-attention layer makes use of Flash Attention, which dramatically
improves the speed of self-attention. Our implementation utilizes a bleeding-edge implementation that
supports attention biases, which allows us to use Flash Attention with ALiBi.
|
|
|
4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT |
|
implementations waste computation on padded tokens. Mosaic BERT internally unpads to reduce unnecessary computation |
|
and improve speed. It does this without changing how the user interfaces with the model, thereby |
|
preserving the simple API of standard implementations. |
|
|
|
|
|
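Currently, Mosaic BERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.

Note that the code in this file departs from the list above in several places: positional information comes
from rotary position embeddings (RoPE) applied to queries and keys inside self-attention rather than from ALiBi
biases (the attention bias carries only the padding mask), all normalization layers are RMSNorm rather than
LayerNorm, and self-attention is computed with plain PyTorch operations rather than a Flash Attention kernel.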
|
|
|
See :file:`./mosaic_bert.py` for utilities to simplify working with Mosaic BERT in Composer, and for example usage |
|
of the core Mosaic BERT classes. |
|
""" |
|
|
|
import copy |
|
import logging |
|
import math |
|
import os |
|
import sys |
|
import warnings |
|
from typing import List, Optional, Tuple, Union |
|
from .configuration_bert import BertConfig |
|
|
|
sys.path.append(os.path.dirname(os.path.realpath(__file__))) |
|
|
|
from .bert_padding import (index_first_axis, |
|
index_put_first_axis, pad_input, |
|
unpad_input, unpad_input_only) |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
from einops import rearrange |
|
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present |
|
from transformers.activations import ACT2FN |
|
from transformers.modeling_outputs import (MaskedLMOutput, |
|
SequenceClassifierOutput) |
|
from transformers.models.bert.modeling_bert import BertPreTrainedModel |
|
logger = logging.getLogger(__name__) |
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization, equivalent to T5's ``LayerNorm``.

    No mean is subtracted and there is no bias: inputs are divided by the RMS
    of the last dimension (computed in float32) and then scaled by a learned
    per-feature ``weight``. When ``weight`` is float16/bfloat16 the normalized
    activations are cast to that dtype before scaling.

    Args:
        hidden_size: Size of the last (normalized) dimension.
        eps: Value added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)
        half_precision = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_precision:
            normalized = normalized.to(self.weight.dtype)
        return self.weight * normalized
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """Generate cos/sin tables for rotary position embeddings (RoPE).

    Tables for positions ``[0, max_position_embeddings)`` are built up front
    and rebuilt (grown) whenever a longer sequence is requested.

    Args:
        dim: Per-head feature size. The frequency schedule uses
            ``arange(0, dim, 2)``, so the table width is ``dim`` only when
            ``dim`` is even.
        max_position_embeddings: Number of positions cached initially.
        base: Base of the geometric frequency schedule.
        device: Device on which ``inv_freq`` and the initial tables live.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Kept as plain attributes (not buffers) so module-wide dtype casts
        # such as ``.half()`` leave the tables in float32.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached, self.sin_cached = self._build_cache(
            max_position_embeddings, self.inv_freq.device)

    def _build_cache(self, seq_len, device):
        """Return float32 ``(cos, sin)`` tables of shape (1, 1, seq_len, width) on ``device``."""
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :].to(torch.float32).to(device)
        sin = emb.sin()[None, None, :, :].to(torch.float32).to(device)
        return cos, sin

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables sliced to ``seq_len`` positions.

        Args:
            x: Tensor whose device the tables are moved to. When ``seq_len``
                is None, ``x.shape[-2]`` is used as the sequence length
                (assumes x is laid out (..., seq, head_dim)).
            seq_len: Number of positions to return.

        Returns:
            Tuple of float32 tensors, each of shape (1, 1, seq_len, width).
        """
        if seq_len is None:
            # The old default crashed on ``None > int``; infer it instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._build_cache(seq_len, x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )
|
|
|
|
|
def rotate_half(x):
    """Map the last dimension ``[a, b]`` to ``[-b, a]``.

    ``a`` is the first ``x.shape[-1] // 2`` features and ``b`` is the rest,
    so for an odd width ``b`` is one element longer than ``a``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos_, sin_):
    """Rotate query and key tensors by the rotary position tables.

    Args:
        q: Query tensor, (batch, heads, seq, head_dim) as built in
            ``BertUnpadSelfAttention.forward``.
        k: Key tensor with the same layout as ``q``.
        cos_: Cosine table of shape (1, 1, seq, head_dim), e.g. from
            ``RotaryEmbedding.forward``.
        sin_: Sine table with the same shape as ``cos_``.

    Returns:
        ``(q_embed, k_embed)``, computed in float32 and cast back to the
        dtypes of ``q`` and ``k`` respectively.
    """
    # The (1, 1, seq, head_dim) tables broadcast over batch and heads. The
    # previous repeat_interleave over the batch dimension produced the same
    # values but materialized a full per-batch copy of each table.
    q_float = q.float()
    k_float = k.float()
    q_embed = (q_float * cos_) + (rotate_half(q_float) * sin_)
    k_embed = (k_float * cos_) + (rotate_half(k_float) * sin_)
    return q_embed.to(q.dtype), k_embed.to(k.dtype)
|
|
|
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and token-type ids, without positions.

    There are no learned positional embeddings. In this file, position
    information is injected later by the rotary embeddings applied inside
    self-attention. The sum of word and token-type embeddings is normalized
    with ``RMSNorm`` (attribute ``norm``) and passed through dropout.

    This module is modeled after the Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`. The
    ``position_ids`` and ``past_key_values_length`` arguments of ``forward``
    are accepted for API compatibility and ignored.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.norm = RMSNorm(config.hidden_size,
                            eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # 1-D (max_position_embeddings,) buffer of zeros used when the caller
        # omits token_type_ids; non-persistent, so absent from state_dict.
        self.register_buffer('token_type_ids',
                             torch.zeros(config.max_position_embeddings,
                                         dtype=torch.long),
                             persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed a batch of tokens.

        Args:
            input_ids: (batch, seq) token ids. Exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given.
            token_type_ids: Optional (batch, seq) segment ids; zeros if omitted.
            position_ids: Ignored.
            inputs_embeds: Optional precomputed (batch, seq, hidden) word embeddings.
            past_key_values_length: Ignored.

        Returns:
            (batch, seq, hidden) embeddings after normalization and dropout.

        Raises:
            ValueError: If both or neither of ``input_ids``/``inputs_embeds`` are given.
        """
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            assert inputs_embeds is not None
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape[0], input_shape[1]

        if token_type_ids is None:
            buffered = getattr(self, 'token_type_ids', None)
            if buffered is not None and buffered.shape[0] >= seq_length:
                # The buffer is 1-D, so slice it directly (the old
                # ``[:, :seq_length]`` raised IndexError) and broadcast it
                # across the batch without copying.
                token_type_ids = buffered[:seq_length].expand(
                    batch_size, seq_length)
            else:
                # Missing buffer or a sequence longer than the buffer.
                token_type_ids = torch.zeros(
                    input_shape,
                    dtype=torch.long,
                    device=self.word_embeddings.weight.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
class BertUnpadSelfAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    Rotary position embeddings (``RotaryEmbedding``) are applied to the query
    and key projections, and attention is computed with plain PyTorch
    operations (``torch.matmul``, softmax, dropout). The ``bias`` argument of
    ``forward`` is added to the attention scores before the softmax.

    NOTE(review): earlier documentation for this module described a Triton /
    Flash Attention fast path with ALiBi biases. No such path exists in this
    class; every call goes through the PyTorch implementation in ``forward``.

    See `forward` method for additional detail.
    """

    def __init__(self, config):
        super().__init__()
        # The head count must divide the hidden size unless the config
        # declares a separate ``embedding_size``.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, 'embedding_size'):
            raise ValueError(
                f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
                f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Stored for reference; not read anywhere in this class.
        self.p_dropout = config.attention_probs_dropout_prob
        # Fused projection producing query, key and value in one matmul.
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
        self.max_position_embeddings = config.max_position_embeddings
        # Per-head rotary tables; the cache initially covers
        # max_position_embeddings positions.
        self.rotary_emb = RotaryEmbedding(self.attention_head_size, max_position_embeddings=self.max_position_embeddings)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                max_seqlen_in_batch: int, indices: torch.Tensor,
                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Perform self-attention.

        The inputs are unpadded, but attention here operates on padded
        tensors. So the fused QKV projection is first re-padded with
        `pad_input`, attention is computed, and the result is unpadded again
        with `unpad_input_only` for the following layers. The pad/unpad round
        trip adds overhead, but the rest of the layer still avoids computing on
        pad tokens.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,); only ``cu_seqlens.shape[0] - 1`` (the
                batch size) is used here.
            max_seqlen_in_batch: int, padded sequence length.
            indices: (total_nnz,), forwarded to `pad_input`.
            attn_mask: (batch, max_seqlen_in_batch); positions equal to 1 are
                kept when the output is unpadded.
            bias: additive attention bias, broadcastable to
                (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch).

        Returns:
            attention: (total_nnz, dim)
        """
        qkv = self.Wqkv(hidden_states)
        # Scatter the unpadded rows back into a padded batch; the rearrange
        # pattern below requires a 3-D (batch, seq, 3 * hidden) result.
        qkv = pad_input(
            qkv, indices, cu_seqlens.shape[0] - 1,
            max_seqlen_in_batch)
        qkv = rearrange(qkv,
                        'b s (t h d) -> b s t h d',
                        t=3,
                        h=self.num_attention_heads)

        # Each of q, k, v becomes (batch, heads, seq, head_dim).
        q = qkv[:, :, 0, :, :].transpose(1, 2)
        k = qkv[:, :, 1, :, :].transpose(1, 2)
        v = qkv[:, :, 2, :, :].transpose(1, 2)
        kv_seq_len = k.shape[-2]

        # Tables are sized by the padded sequence length; v only supplies
        # the target device.
        cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # (batch, heads, head_dim, seq), so q @ k is (batch, heads, seq, seq).
        k = k.permute(0, 1, 3, 2)

        attention_scores = torch.matmul(q, k) / math.sqrt(
            self.attention_head_size)
        attention_scores = attention_scores + bias
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim)
        attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                             3)

        # Drop padded positions again (positions where attn_mask == 1 are kept).
        # NOTE(review): torch.squeeze removes every size-1 dim, including the
        # batch dim when batch == 1; presumably unpad_input_only flattens the
        # mask so this is harmless. Confirm against bert_padding.
        attention = unpad_input_only(
            attention,
            torch.squeeze(attn_mask) == 1)
        return rearrange(attention, 'nnz h d -> nnz (h d)')
|
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual, normalize.

    Modeled after the Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`, but with
    ``RMSNorm`` (attribute ``norm``) in place of ``LayerNorm``. It is
    re-implemented here so that Composer surgery algorithms that modify
    Hugging Face BERT modules leave Mosaic BERT's modules untouched.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.norm = RMSNorm(config.hidden_size,
                            eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.norm(projected + input_tensor)
|
|
|
|
|
class BertUnpadAttention(nn.Module):
    """Unpadded self-attention followed by projection, dropout, residual and RMSNorm."""

    def __init__(self, config):
        super().__init__()
        self.self = BertUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention on unpadded tokens and apply the output block.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int, padded sequence length.
            subset_idx: optional indices along the token axis. When given, both
                the attention output and the residual are restricted to those
                rows (via ``index_first_axis``) before the output block, e.g.
                only the masked tokens in the final layer.
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or an additive bias broadcastable to
                (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attended = self.self(input_tensor, cu_seqlens, max_s, indices,
                             attn_mask, bias)
        residual = input_tensor
        if subset_idx is not None:
            attended = index_first_axis(attended, subset_idx)
            residual = index_first_axis(residual, subset_idx)
        return self.output(attended, residual)
|
|
|
class MLP(nn.Module):
    """Gated feed-forward block with a post-norm residual connection.

    Computes ``norm(x + down_proj(act(gate_proj(x)) * up_proj(x)))``. The
    three projections have no bias, and ``config.hidden_dropout_prob`` is not
    applied in this block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        gated = self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        return self.norm(self.down_proj(gated) + hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertLayer(nn.Module):
    """One Mosaic BERT layer: unpadded self-attention, then the gated MLP."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertUnpadAttention(config)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlen: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention and the MLP to unpadded tokens.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            seqlen: int, padded sequence length.
            subset_idx: optional token indices to keep after attention
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or an additive attention bias broadcastable to
                (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attended = self.attention(hidden_states, cu_seqlens, seqlen,
                                  subset_idx, indices, attn_mask, bias)
        return self.mlp(attended)
|
|
|
|
|
class BertEncoder(nn.Module):
    """A stack of BERT layers providing the backbone of Mosaic BERT.

    This module is modeled after the Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertEncoder`, but with
    substantial modifications to implement unpadding.

    Padding tokens are removed before the layers run (and restored
    afterwards) to avoid computation at padded positions. The attention bias
    built here carries only the padding mask; position information comes from
    the rotary embeddings inside each attention layer.
    """

    def __init__(self, config):
        super().__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

        self.num_attention_heads = config.num_attention_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Run every layer on the non-padding tokens.

        Args:
            hidden_states: (batch, seqlen, dim) embeddings.
            attention_mask: (batch, seqlen); 1 for tokens to attend to, 0 for padding.
            output_all_encoded_layers: If True, return one entry per layer.
            subset_mask: Optional (batch, seqlen) boolean mask. When given, the
                last layer is evaluated only on selected non-padding tokens,
                and its output is (num_selected, dim) rather than re-padded.

        Returns:
            A list of tensors.
            - If ``output_all_encoded_layers`` is False, the list holds one
              entry: the final output re-padded to (batch, seqlen, dim), or the
              (num_selected, dim) subset when ``subset_mask`` is given.
            - If True, the list holds every layer's output. Each is re-padded to
              (batch, seqlen, dim), except a subset-restricted final layer.
        """
        # Additive mask: 0 where attention_mask == 1, -10000 where it is 0.
        # Shape (batch, 1, 1, seqlen) broadcasts over heads and query positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]
        # Remove padding: hidden_states becomes (total_nnz, dim).
        hidden_states, indices, cu_seqlens, _ = unpad_input(
            hidden_states, attention_mask_bool)

        attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]

        # With a subset mask, the last layer runs separately (below) on the
        # selected tokens only.
        num_full_layers = (len(self.layer) if subset_mask is None else
                           len(self.layer) - 1)
        all_encoder_layers = []
        for layer_module in self.layer[:num_full_layers]:
            hidden_states = layer_module(hidden_states,
                                         cu_seqlens,
                                         seqlen,
                                         None,
                                         indices,
                                         attn_mask=attention_mask,
                                         bias=attn_bias)
            if output_all_encoded_layers:
                # Re-pad so every returned layer is (batch, seqlen, dim), as
                # BertModel's pooler and docs expect; previously the unpadded
                # (total_nnz, dim) tensors were returned here.
                all_encoder_layers.append(
                    pad_input(hidden_states, indices, batch, seqlen))

        if subset_mask is None:
            # With output_all_encoded_layers the padded final output is
            # already the last list entry.
            if not output_all_encoded_layers:
                all_encoder_layers.append(
                    pad_input(hidden_states, indices, batch, seqlen))
        else:
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                       as_tuple=False).flatten()
            hidden_states = self.layer[-1](hidden_states,
                                           cu_seqlens,
                                           seqlen,
                                           subset_idx=subset_idx,
                                           indices=indices,
                                           attn_mask=attention_mask,
                                           bias=attn_bias)
            # Always keep the final layer's output; it used to be dropped
            # when output_all_encoded_layers was True.
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
|
|
|
|
|
class BertPooler(nn.Module):
    """Dense + tanh pooling of the first token (or of pre-selected rows)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self,
                hidden_states: torch.Tensor,
                pool: Optional[bool] = True) -> torch.Tensor:
        """Pool ``hidden_states``.

        With ``pool`` true, the first position of each sequence
        (``hidden_states[:, 0]``) is used. Otherwise the input is assumed to
        already hold the rows to pool.
        """
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> RMSNorm transform applied before the MLM decoder.

    ``config.hidden_act`` may be a string key into ``ACT2FN`` or a callable
    used as-is. The norm epsilon is fixed at 1e-12 rather than read from the
    config.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.norm = RMSNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.norm(self.transform_act_fn(self.dense(hidden_states)))
|
|
|
|
|
class BertModel(BertPreTrainedModel):
    """Overall BERT model.

    Args:
        config: a BertConfig class instance with the configuration to build a new model
        add_pooling_layer: whether to build the ``BertPooler``; when False,
            ``pooled_output`` is always None.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details). Defaults to all zeros.
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]; 0 marks padding. Defaults to all ones.
        `position_ids`: accepted for API compatibility and forwarded to the
            embeddings, which ignore it.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as
            described below. Default: `False`.
        `masked_tokens_mask`: optional boolean tensor of shape [batch_size, sequence_length]. When given, the
            last encoder layer is evaluated only on the masked tokens plus the first token of every sequence.
        Any other keyword arguments are accepted and ignored.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: the full list returned by the encoder, one entry per layer.
            - `output_all_encoded_layers=False`: only the final hidden states. This is
              [batch_size, sequence_length, hidden_size] without `masked_tokens_mask`, or
              [num_masked_tokens, hidden_size] (masked, non-padding tokens only) with it.
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] produced by the pooler from
            the first token (`CLS`) of each sequence, or None when the model has no pooling layer.

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertModel(config=config)
    encoded_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, add_pooling_layer=True):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_all_encoded_layers: Optional[bool] = False,
        masked_tokens_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embedding_output = self.embeddings(input_ids, token_type_ids,
                                           position_ids)

        if masked_tokens_mask is None:
            subset_mask = None
            first_col_mask = None
        else:
            # Always keep the first token of each sequence so the pooler can
            # still read it from the subset-restricted final layer.
            first_col_mask = torch.zeros_like(masked_tokens_mask)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask)

        if masked_tokens_mask is None:
            sequence_output = encoder_outputs[-1]
            pooled_output = self.pooler(
                sequence_output) if self.pooler is not None else None
        else:
            # The final encoder output holds only the subset rows, in the
            # order of non-padding tokens; map each mask onto those rows.
            attention_mask_bool = attention_mask.bool()
            in_subset = subset_mask[attention_mask_bool]
            sequence_output = encoder_outputs[-1][
                masked_tokens_mask[attention_mask_bool][in_subset]]
            if self.pooler is not None:
                pool_input = encoder_outputs[-1][
                    first_col_mask[attention_mask_bool][in_subset]]
                pooled_output = self.pooler(pool_input, pool=False)
            else:
                pooled_output = None

        if not output_all_encoded_layers:
            encoder_outputs = sequence_output

        # pooled_output is already None when there is no pooler.
        return encoder_outputs, pooled_output
|
|
|
|
|
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """MLM decoder: transform, then project onto L2-normalized vocabulary rows.

    ``self.weight`` is a new parameter with the same shape as
    ``bert_model_embedding_weights``, Kaiming-uniform initialized. It does not
    share storage with those weights. Every row is L2-normalized on each
    forward pass, and no output bias is used.
    """

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.weight = nn.Parameter(torch.empty((bert_model_embedding_weights.size(0), bert_model_embedding_weights.size(1))))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Retained for backward compatibility with code that inspects it;
        # forward no longer consults it.
        self.first_flag = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform(hidden_states)
        # Normalize in every mode. The former eval-mode shortcut overwrote
        # ``self.weight.data`` on the first eval call and then reused it
        # unnormalized. That silently mutated the parameter, and any weights
        # loaded or changed after that call went unnormalized.
        norm_weight = nn.functional.normalize(self.weight)
        return nn.functional.linear(hidden_states, norm_weight)
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Holds the MLM prediction head under the ``predictions`` attribute."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(
            config, bert_model_embedding_weights)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        return self.predictions(sequence_output)
|
|
|
|
|
class BertOnlyNSPHead(nn.Module):
    """Two-way next-sentence-prediction classifier over the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        return self.seq_relationship(pooled_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertForPreTraining(BertPreTrainedModel):
    """Placeholder with no Mosaic BERT implementation; adds nothing to ``BertPreTrainedModel``."""
|
|
|
|
|
class BertLMHeadModel(BertPreTrainedModel):
    """Placeholder with no Mosaic BERT implementation; adds nothing to ``BertPreTrainedModel``."""
|
|
|
|
|
class BertForMaskedLM(BertPreTrainedModel):
    """Mosaic BERT with a masked-language-modeling head.

    The encoder is built without a pooling layer. When ``labels`` are given,
    positions with ``labels > 0`` are treated as masked targets. Only those
    positions, plus the first token of each sequence, run through the final
    encoder layer, and only the masked positions reach the MLM head. The loss
    is the cross-entropy on those positions plus a fixed-weight (0.2) penalty
    on the mean squared maximum logit.
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            warnings.warn(
                'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                'bi-directional self-attention.')
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config,
                                   self.bert.embeddings.word_embeddings.weight)

        # Initialize weights and apply final processing.
        self.post_init()

    @classmethod
    def from_composer(cls,
                      pretrained_checkpoint,
                      state_dict=None,
                      cache_dir=None,
                      from_tf=False,
                      config=None,
                      *inputs,
                      **kwargs):
        """Build a model from ``config`` and load weights from a Composer checkpoint.

        Args:
            pretrained_checkpoint: Path or file-like object passed to
                ``torch.load``. Only read when ``state_dict`` is None.
            state_dict: Optional already-loaded state dict, used instead of
                reading ``pretrained_checkpoint``. It was previously ignored.
                Its ``model.`` key prefix is stripped in place.
            cache_dir: Unused; accepted for API compatibility.
            from_tf: Must be False; TensorFlow weights are not supported.
            config: Configuration passed to the constructor.
            *inputs: Extra positional constructor arguments.
            **kwargs: Extra keyword constructor arguments.

        Returns:
            The model with the weights loaded non-strictly. Missing and
            unexpected keys are logged as warnings.

        Raises:
            ValueError: If ``from_tf`` is True.
        """
        model = cls(config, *inputs, **kwargs)
        if from_tf:
            raise ValueError(
                'Mosaic BERT does not support loading TensorFlow weights.')

        if state_dict is None:
            # SECURITY: torch.load unpickles arbitrary objects; only load
            # trusted checkpoints. map_location='cpu' lets GPU-saved
            # checkpoints load on any host; load_state_dict then copies the
            # values into the model's parameters on their own device.
            state_dict = torch.load(pretrained_checkpoint, map_location='cpu')

        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
                                                              strict=False)

        if len(missing_keys) > 0:
            logger.warning(
                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
            )
        if len(unexpected_keys) > 0:
            logger.warning(
                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
            )

        return model

    def get_output_embeddings(self):
        """Return the MLM decoder weight (an ``nn.Parameter``, not a module)."""
        return self.cls.predictions.weight

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.weight = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Run the encoder and the MLM head.

        Args:
            input_ids: (batch, seq) token ids. Exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given. NOTE(review): ``BertModel``
                only consumes ``input_ids``, so calls without it fail downstream.
            labels: Optional (batch, seq) targets. Positions with
                ``labels > 0`` are the masked targets; all others are ignored.
            return_dict: Defaults to ``config.use_return_dict``.
            head_mask, inputs_embeds, encoder_hidden_states,
            encoder_attention_mask, output_attentions, output_hidden_states:
                Forwarded to ``BertModel.forward``, which takes them only
                through ``**kwargs`` and does not use them.

        Returns:
            ``MaskedLMOutput`` (or a tuple when ``return_dict`` is false).
            ``logits`` are (batch, seq, vocab). When ``labels`` are given, only
            the masked positions are scattered back via
            ``index_put_first_axis``.

        Raises:
            ValueError: If both or neither of ``input_ids``/``inputs_embeds`` are given.
        """
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')

        if labels is None:
            masked_tokens_mask = None
        else:
            masked_tokens_mask = labels > 0

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            masked_tokens_mask=masked_tokens_mask,
        )

        # With labels this is (num_masked_tokens, hidden); without, (batch, seq, hidden).
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Penalty on the squared maximum logit, presumably a cheap proxy
            # for the softmax normalizer (z-loss).
            softmax_normalizer = prediction_scores.max(-1).values ** 2
            z_loss_weight = 0.2
            z_loss = z_loss_weight * softmax_normalizer.mean()

            masked_token_idx = torch.nonzero(labels.flatten() > 0,
                                             as_tuple=False).flatten()

            loss = loss_fct(prediction_scores,
                            labels.flatten()[masked_token_idx]) + z_loss
            assert input_ids is not None, 'Coding error; please open an issue'
            batch, seqlen = input_ids.shape[:2]
            # Scatter the masked-token logits back to a (batch, seq, vocab) layout.
            prediction_scores = rearrange(
                index_put_first_axis(
                    prediction_scores, masked_token_idx, batch * seqlen),
                '(b s) d -> b s d',
                b=batch)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
                                      attention_mask: torch.Tensor,
                                      **model_kwargs):
        """Append one PAD token (with attention mask 0) to every sequence.

        Raises:
            ValueError: If ``config.pad_token_id`` is None.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        if self.config.pad_token_id is None:
            raise ValueError('The PAD token should be defined for generation')

        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_zeros((attention_mask.shape[0], 1))
        ],
                                   dim=-1)
        dummy_token = torch.full((effective_batch_size, 1),
                                 self.config.pad_token_id,
                                 dtype=torch.long,
                                 device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
|
|
|
class BertForNextSentencePrediction(BertPreTrainedModel):
    """Placeholder with no Mosaic BERT implementation; adds nothing to ``BertPreTrainedModel``."""
|
|
|
|
|
class BertForSequenceClassification(BertPreTrainedModel):
    """Bert Model transformer with a sequence classification/regression head.

    This head is just a linear layer on top of the pooled output. Used for,
    e.g., GLUE tasks.
    """

    config_class = BertConfig

    def __init__(self, config):
        """Build the encoder, dropout and linear classification head.

        Args:
            config: Model config; ``num_labels``, ``hidden_size``,
                ``classifier_dropout`` and ``hidden_dropout_prob`` are read
                here.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        # Prefer a classifier-specific dropout rate; fall back to the
        # encoder's hidden dropout probability when it is not configured.
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else
                              config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @classmethod
    def from_composer(cls,
                      pretrained_checkpoint,
                      state_dict=None,
                      cache_dir=None,
                      from_tf=False,
                      config=None,
                      *inputs,
                      **kwargs):
        """Load from pre-trained.

        Builds ``cls(config, *inputs, **kwargs)`` and loads weights into it
        non-strictly, logging any missing or unexpected keys.

        Args:
            pretrained_checkpoint: Path or file-like object passed to
                ``torch.load``. Only read when ``state_dict`` is None.
            state_dict: Optional already-loaded state dict. When given it is
                used instead of reading ``pretrained_checkpoint``. Note that
                a ``'model.'`` key prefix is stripped from it in place.
            cache_dir: Unused; accepted for API compatibility.
            from_tf: Must be False; TensorFlow weights are not supported.
            config: Config passed to the class constructor.
            *inputs: Extra positional arguments for the constructor.
            **kwargs: Extra keyword arguments for the constructor.

        Returns:
            The constructed model with the checkpoint weights loaded.

        Raises:
            ValueError: If ``from_tf`` is True.
        """
        # Reject unsupported input before paying for model construction.
        if from_tf:
            raise ValueError(
                'Mosaic BERT does not support loading TensorFlow weights.')

        model = cls(config, *inputs, **kwargs)

        if state_dict is None:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code; only load checkpoints from trusted sources.
            # map_location='cpu' lets checkpoints saved on a GPU load on hosts
            # without that device; load_state_dict below copies the values
            # into the model's own parameters regardless of where they live.
            state_dict = torch.load(pretrained_checkpoint, map_location='cpu')

        # Drop a leading 'model.' from keys (in place), presumably added by
        # the Composer model wrapper -- confirm against checkpoint format.
        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
                                                              strict=False)

        if missing_keys:
            logger.warning('Found these missing keys in the checkpoint: %s',
                           ', '.join(missing_keys))
        if unexpected_keys:
            logger.warning('Found these unexpected keys in the checkpoint: %s',
                           ', '.join(unexpected_keys))

        return model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify (or regress on) each input sequence.

        All inputs except ``labels`` are forwarded to ``self.bert``. Its
        second output (``outputs[1]``) is passed through dropout and the
        linear classifier to produce the logits.

        Args:
            labels: Optional targets. When given, a loss is computed. The
                loss is chosen by ``self.config.problem_type``. If that is
                None, it is inferred and stored on the config:

                * ``'regression'`` when ``num_labels == 1``;
                * ``'single_label_classification'`` when ``num_labels > 1``
                  and labels are ``torch.long``/``torch.int``;
                * ``'multi_label_classification'`` otherwise.
            return_dict: Whether to return a ``SequenceClassifierOutput``;
                defaults to ``self.config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutput(loss, logits)`` with ``hidden_states``
            and ``attentions`` always None. When ``return_dict`` is false,
            the result is a tuple instead:

            * ``(loss, logits, *outputs[2:])``, or
            * ``(logits, *outputs[2:])`` when ``loss`` is None.

            ``loss`` is None when no labels are given, or when
            ``problem_type`` is not one of the three names above.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Assumes self.bert returns (sequence_output, pooled_output, ...)
        # as a tuple-like -- TODO confirm against BertModel.forward.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the task type once; the result is cached on the config
            # so later calls reuse it.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long or
                                              labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    # Drop size-1 dims so (batch, 1) logits match (batch,)
                    # labels.
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
|
|
|
|
|
class BertForMultipleChoice(BertPreTrainedModel):
    """Placeholder for a multiple-choice head.

    Not implemented yet: the class adds nothing on top of
    ``BertPreTrainedModel``.
    """
|
|
|
|
|
class BertForTokenClassification(BertPreTrainedModel):
    """Placeholder for a token-classification head.

    Not implemented yet: the class adds nothing on top of
    ``BertPreTrainedModel``.
    """
|
|
|
|
|
class BertForQuestionAnswering(BertPreTrainedModel): |
|
"""Bert Model with a span classification head. |
|
|
|
This is used for extractive question-answering tasks like SQuAD (a linear |
|
layers on top of the hidden states' output to compute `span start logits` |
|
and `span end logits`). |
|
""" |
|
|