import torch
import torch.nn.functional as F
from transformers import PreTrainedModel

from .original import TransformerModel, LMHead

'''
Code for HuggingFace Hub Compatibility
'''


class HF_LMModel(PreTrainedModel):
    """ Transformer with language model head only """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        if self.return_probs or self.return_acts:
            # Mask out the positional-embedding slots at the end of the
            # vocabulary so they can never be predicted as tokens.
            pos_emb_mask = torch.zeros(1, 1, config.n_vocab)
            pos_emb_mask[:, :, -config.n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            # Return a softmax distribution over the vocabulary.
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            # Return raw activations with the position slots masked out.
            lm_logits = lm_logits + self.pos_emb_mask
        return {
            "logits": lm_logits
        }
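

# Usage sketch (illustrative only, not part of this repo): it assumes a
# `config` object that already carries the fields read above (n_vocab, n_ctx,
# return_probs, return_acts) plus whatever TransformerModel / LMHead in
# .original require, and that `x` is an input batch in whatever index layout
# TransformerModel consumes.
#
#     model = HF_LMModel(config)
#     out = model(x)        # sequence_mask is optional
#     out["logits"]         # softmax probs if config.return_probs,
#                           # masked activations if config.return_acts,
#                           # raw LM-head logits otherwise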