import torch
import torch.nn as nn
from transformers import RoFormerModel, RoFormerPreTrainedModel


class RoFormerForSparseEmbedding(RoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = RoFormerModel(config)
        # Maps each token's hidden state to a single scalar importance weight
        self.linear_layer = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids, attention_mask, return_sparse=False):
        B, L = input_ids.shape
        last_hidden_states = self.encoder(input_ids, attention_mask)['last_hidden_state']  # [B,L,D]
        token_weights = self.linear_layer(last_hidden_states).squeeze(-1)                  # [B,L]

        # Additive mask that suppresses padding tokens as well as the first ([CLS])
        # and last ([SEP]) positions of each sequence
        token_mask = (1 - attention_mask) * -1e4                    # [B,L]
        token_mask[:, 0] = -1e4
        last_ind = torch.sum(attention_mask, -1, keepdim=True) - 1  # [B,1]
        token_mask = torch.scatter(token_mask, -1, last_ind, -1e4)  # [B,L]
        token_weights = token_weights + token_mask                  # [B,L]

        # Scatter each token's weight into a vocabulary-sized vector, then max-pool
        # over the sequence so repeated tokens keep their highest weight
        emb = torch.zeros(B, L, self.encoder.config.vocab_size,
                          dtype=token_weights.dtype, device=token_weights.device)  # [B,L,V]
        emb = torch.scatter(emb, dim=-1, index=input_ids.unsqueeze(-1),
                            src=token_weights.unsqueeze(-1))                       # [B,L,V]
        emb = torch.max(torch.relu(emb), dim=-2).values                            # [B,V]

        if return_sparse:
            emb = emb.to_sparse()

        return emb
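

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition above): builds the
# model from a small, randomly initialized RoFormerConfig just to exercise the
# forward pass. The config values below are illustrative assumptions; in
# practice you would load pretrained weights with
# RoFormerForSparseEmbedding.from_pretrained(<checkpoint>).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import RoFormerConfig

    config = RoFormerConfig(
        vocab_size=50000,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
    )
    model = RoFormerForSparseEmbedding(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # [B,L] dummy token ids
    attention_mask = torch.ones(2, 16, dtype=torch.long)      # [B,L] no padding
    with torch.no_grad():
        emb = model(input_ids, attention_mask, return_sparse=True)  # sparse [B,V]
    print(emb.shape)  # torch.Size([2, 50000])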