|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from transformers.models.xlm_roberta.modeling_xlm_roberta import \ |
|
create_position_ids_from_input_ids |
|
|
|
|
|
class XLMRobertaEmbeddings(nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim, |
|
vocab_size, |
|
max_position_embeddings, |
|
type_vocab_size, |
|
padding_idx=None, |
|
device=None, |
|
dtype=None, |
|
): |
|
""" |
|
If max_position_embeddings <= 0, there's no position embeddings |
|
If type_vocab_size <= 0, there's no token type embeddings |
|
""" |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
self.word_embeddings = nn.Embedding( |
|
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs |
|
) |
|
self.max_position_embeddings = max_position_embeddings |
|
self.type_vocab_size = type_vocab_size |
|
if self.max_position_embeddings > 0: |
|
self.position_embeddings = nn.Embedding( |
|
max_position_embeddings, embed_dim, **factory_kwargs |
|
) |
|
if self.type_vocab_size > 0: |
|
self.token_type_embeddings = nn.Embedding( |
|
type_vocab_size, embed_dim, **factory_kwargs |
|
) |
|
|
|
def forward( |
|
self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None |
|
): |
|
""" |
|
input_ids: (batch, seqlen) |
|
position_ids: (batch, seqlen) |
|
token_type_ids: (batch, seqlen) |
|
adapter_mask: (batch, 1) |
|
""" |
|
batch_size, seqlen = input_ids.shape |
|
if adapter_mask is not None: |
|
unique_tasks = torch.unique(adapter_mask) |
|
embedding_dtype = next(self.word_embeddings.parameters()).dtype |
|
embeddings = torch.empty( |
|
*input_ids.shape, |
|
self.word_embeddings.embedding_dim, |
|
dtype=embedding_dtype, |
|
device=input_ids.device |
|
) |
|
for task_id in unique_tasks: |
|
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0] |
|
task_input_ids = input_ids[task_indices] |
|
task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id) |
|
embeddings[task_indices] = task_embeddings |
|
else: |
|
embeddings = self.word_embeddings(input_ids) |
|
if self.max_position_embeddings > 0: |
|
if position_ids is None: |
|
position_ids = create_position_ids_from_input_ids( |
|
input_ids, padding_idx=self.word_embeddings.padding_idx |
|
).to(input_ids.device) |
|
position_embeddings = self.position_embeddings(position_ids) |
|
embeddings = embeddings + position_embeddings |
|
if self.type_vocab_size > 0: |
|
if token_type_ids is None: |
|
token_type_ids = torch.zeros( |
|
seqlen, dtype=torch.long, device=input_ids.device |
|
) |
|
|
|
if adapter_mask is not None: |
|
unique_tasks = torch.unique(adapter_mask) |
|
for task_id in unique_tasks: |
|
task_token_type_embeddings = self.token_type_embeddings( |
|
token_type_ids, task_id=task_id |
|
) |
|
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0] |
|
embeddings[task_indices] = ( |
|
embeddings[task_indices] + task_token_type_embeddings |
|
) |
|
else: |
|
token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
embeddings = embeddings + token_type_embeddings |
|
return embeddings |
|
|