from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import UMT5Model

from .configuration_rankingprompter import RankingPrompterConfig


@dataclass
class RankingPrompterForPreTrainingOutput:
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class RankingPrompterForPreTraining(UMT5Model):
    config_class = RankingPrompterConfig
    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, config):
        # encoder, decoder and shared embeddings come from UMT5Model
        super().__init__(config)

        # ranking head: maps the pooled decoder state of each document to a score
        self.ranking_head = nn.Linear(config.d_model, 1)

        # Initialize weights and apply final processing
        self.post_init()

        # autocast context for mixed precision training (no-op by default)
        self.ctx = nullcontext()

    def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
        self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)

    def disable_amp_ctx(self):
        self.ctx = nullcontext()

    def forward(
        self,
        document_input_ids: Optional[torch.LongTensor] = None,
        document_attention_mask: Optional[torch.FloatTensor] = None,
        question_input_ids: Optional[torch.LongTensor] = None,
        question_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], RankingPrompterForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the document ranking loss. Each label is the index of
            the positive document in `[0, ..., num_doc - 1]`; labels set to `-100` are
            ignored (masked).

        Returns:
            `RankingPrompterForPreTrainingOutput` if `return_dict=True`, otherwise a tuple.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # document_input_ids: [batch_size, num_doc, doc_seq_len]
        batch_size, num_doc, doc_seq_len = document_input_ids.shape
        # flatten documents to [batch_size * num_doc, doc_seq_len]
        document_input_ids = document_input_ids.view(-1, doc_seq_len)
        document_attention_mask = document_attention_mask.view(-1, doc_seq_len)

        # encode all documents in one pass
        with self.ctx:
            encoder_outputs = self.encoder(
                input_ids=document_input_ids,
                attention_mask=document_attention_mask,
                return_dict=return_dict,
            )
        document_embeds = encoder_outputs[0]

        # repeat question inputs for each document
        # question_input_ids: [batch_size, question_seq_len]
        question_seq_len = question_input_ids.shape[1]
        question_input_ids = (
            question_input_ids.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]
        question_attention_mask = (
            question_attention_mask.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]

        # decode the question against each document's encoder states
        with self.ctx:
            decoder_outputs = self.decoder(
                input_ids=question_input_ids,
                attention_mask=question_attention_mask,
                past_key_values=past_key_values,
                encoder_hidden_states=document_embeds,
                encoder_attention_mask=document_attention_mask,
                use_cache=use_cache,
                return_dict=return_dict,
            )

        # [batch_size * num_doc, question_seq_len, hidden_size]
        sequence_output = decoder_outputs[0]
        question_seq_len = sequence_output.size(1)
        # [batch_size, num_doc, question_seq_len, hidden_size]
        soft_prompt_output = sequence_output.view(
            batch_size, num_doc, question_seq_len, -1
        )

        # mean-pool over question tokens, then score each document:
        # [batch_size, num_doc, hidden_size] -> [batch_size, num_doc, 1]
        ranking_logits = self.ranking_head(soft_prompt_output.mean(dim=2))

        # ranking loss: labels index the positive document for each example
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            ranking_logits = ranking_logits.view(batch_size, num_doc)
            loss = loss_fct(ranking_logits, labels)

        if not return_dict:
            output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return RankingPrompterForPreTrainingOutput(
            loss=loss, logits=ranking_logits
        )
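

# --- Minimal usage sketch (illustrative only) --------------------------------
# This block is not part of the model; it only shows the tensor shapes the
# forward pass expects. It assumes RankingPrompterConfig can be instantiated
# with defaults and exposes `vocab_size`; the random token ids stand in for a
# real tokenized batch of documents and a question.
if __name__ == "__main__":
    batch_size, num_doc, doc_seq_len, question_seq_len = 2, 4, 32, 16
    config = RankingPrompterConfig()  # assumed to provide sensible defaults
    model = RankingPrompterForPreTraining(config)

    # documents: [batch_size, num_doc, doc_seq_len]
    document_input_ids = torch.randint(
        0, config.vocab_size, (batch_size, num_doc, doc_seq_len)
    )
    document_attention_mask = torch.ones_like(document_input_ids)
    # question: [batch_size, question_seq_len]
    question_input_ids = torch.randint(
        0, config.vocab_size, (batch_size, question_seq_len)
    )
    question_attention_mask = torch.ones_like(question_input_ids)
    # labels: index of the positive document for each example
    labels = torch.randint(0, num_doc, (batch_size,))

    outputs = model(
        document_input_ids=document_input_ids,
        document_attention_mask=document_attention_mask,
        question_input_ids=question_input_ids,
        question_attention_mask=question_attention_mask,
        labels=labels,
    )
    print(outputs.loss, outputs.logits.shape)  # logits: [batch_size, num_doc]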