import torch
import torch.nn.functional as F
from torch import nn
from transformers import BloomForCausalLM, PreTrainedModel

from .configuration import BufferEmbeddingConfig


class DualModel(PreTrainedModel):
    """Dual-encoder sentence-embedding model built on a BLOOM causal LM."""

    config_class = BufferEmbeddingConfig
    _auto_class = "AutoModel"

    def __init__(self, config):
        super().__init__(config)
        self.model = BloomForCausalLM(config)  # .from_pretrained('Langboat/bloom-800m-zh')
        # 1536 was hard-coded here; it is the hidden size of
        # Langboat/bloom-800m-zh. Reading it from the config keeps the
        # projection heads in sync with the backbone.
        hidden_size = config.hidden_size
        self.classifier = nn.Linear(hidden_size, hidden_size)
        self.hidden = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())

    def forward(self,
                input_ids,
                token_type_ids=None,  # accepted for HF compatibility, unused
                position_ids=None,    # accepted for HF compatibility, unused
                attention_mask=None,
                labels=None):
        # Rebuild the mask from the inputs: token id 3 is BLOOM's pad token,
        # so any caller-supplied attention_mask is deliberately overridden.
        attention_mask = torch.ne(input_ids, 3)  # shape: (batch_size, max_len)
        y = self.model(input_ids,
                       attention_mask=attention_mask,
                       output_hidden_states=True)
        # Mean-pool the last hidden layer over non-padding positions only.
        embedding = ((y.hidden_states[-1] * attention_mask.unsqueeze(-1)).sum(1)
                     / attention_mask.sum(1).unsqueeze(-1))
        embedding = self.classifier(self.hidden(embedding))
        # L2-normalize so cosine similarity reduces to a dot product.
        return F.normalize(embedding, p=2, dim=-1)
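
# A minimal usage sketch, not part of the module's public API. Assumptions:
# the Langboat/bloom-800m-zh tokenizer (mentioned above) is available, and
# BufferEmbeddingConfig's defaults match that backbone. Because of the
# relative import, run this as a module (python -m <package>.<this_module>).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-800m-zh")
    config = BufferEmbeddingConfig()  # assumption: defaults fit the backbone
    model = DualModel(config).eval()

    batch = tokenizer(["你好，世界", "今天天气不错"],
                      padding=True, return_tensors="pt")
    with torch.no_grad():
        emb = model(batch["input_ids"])  # (2, hidden_size), L2-normalized
    # Embeddings are unit vectors, so a dot product is cosine similarity.
    print((emb[0] @ emb[1]).item())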