import torch
import torch.nn.functional as F
from transformers import PreTrainedModel
from .original import TransformerModel, LMHead
'''
Code for HuggingFace Hub Compatibility
'''
class HF_LMModel(PreTrainedModel):
    """Transformer with a language-model head only."""

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        if self.return_probs or self.return_acts:
            # The extended vocabulary ends with n_ctx position-embedding slots;
            # push their logits to a large negative value so they are never
            # predicted as tokens.
            pos_emb_mask = torch.zeros(1, 1, config.n_vocab)
            pos_emb_mask[:, :, -config.n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            # Return normalized probabilities over the vocabulary.
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            # Return raw logits with the position slots masked out.
            lm_logits = lm_logits + self.pos_emb_mask
        return {"logits": lm_logits}
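

# --- Usage sketch (illustrative only, not part of the model code) ---
# A minimal sketch of instantiating the model directly from a config and
# running a forward pass. The config field names (n_embd, n_head, n_layer,
# the dropout/activation settings) and the GPT-1-style sizes below are
# assumptions taken from the original OpenAI transformer code, not values
# guaranteed by this repo. The input layout is likewise assumed: a
# (batch, seq, 2) tensor stacking token indices with position indices that
# are offset into the shared embedding table.
if __name__ == '__main__':
    from transformers import PretrainedConfig

    n_ctx = 512
    n_vocab = 40478 + n_ctx  # token vocab plus n_ctx position slots (assumed layout)
    config = PretrainedConfig(
        n_vocab=n_vocab, n_ctx=n_ctx,
        n_embd=768, n_head=12, n_layer=12,                             # assumed sizes
        embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1, afn='gelu',   # assumed fields
        return_probs=True, return_acts=False,
    )
    model = HF_LMModel(config).eval()

    batch, seq = 2, 16
    tokens = torch.randint(0, n_vocab - n_ctx, (batch, seq))
    positions = torch.arange(n_vocab - n_ctx, n_vocab - n_ctx + seq).expand(batch, seq)
    x = torch.stack([tokens, positions], dim=-1)  # (batch, seq, 2)

    with torch.no_grad():
        out = model(x)
    print(out['logits'].shape)  # expected: (batch, seq, n_vocab), probabilities here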