import torch
import torch.nn.functional as F

from transformers import PreTrainedModel

from .original import TransformerModel, LMHead

'''
Code for HuggingFace Hub Compatibility
'''


class HF_LMModel(PreTrainedModel):
    """ Transformer with language model head only """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        if self.return_probs or self.return_acts:
            # Mask out the last n_ctx vocabulary slots (the positional-embedding
            # ids that share the embedding table with the word vocabulary) so the
            # tied LM head never assigns them probability mass.
            pos_emb_mask = torch.zeros(1, 1, config.n_vocab)
            pos_emb_mask[:, :, -config.n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            # Normalized distribution over the (position-masked) vocabulary.
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            # Raw logits with the positional slots masked out.
            lm_logits = lm_logits + self.pos_emb_mask
        return {"logits": lm_logits}
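

if __name__ == "__main__":
    # Standalone sketch (illustrative only, not part of the module API): it
    # demonstrates what `pos_emb_mask` does without building the full model,
    # since constructing HF_LMModel requires a PretrainedConfig carrying every
    # hyperparameter that TransformerModel and LMHead read from `config`
    # (n_vocab, n_ctx, return_probs, return_acts, plus the rest of the original
    # model's settings). The sizes below are arbitrary placeholders.
    n_vocab, n_ctx = 12, 4

    # Same construction as in __init__: -1e12 on the last n_ctx slots.
    pos_emb_mask = torch.zeros(1, 1, n_vocab)
    pos_emb_mask[:, :, -n_ctx:] = -1e12

    fake_logits = torch.randn(1, 1, n_vocab)
    probs = F.softmax(fake_logits + pos_emb_mask, dim=-1)
    print(probs[0, 0, -n_ctx:])  # ~0 probability on the masked positional slots
    print(probs.sum(dim=-1))     # still a valid distribution (sums to 1)

    # Because HF_LMModel subclasses transformers.PreTrainedModel, a properly
    # configured instance gets the standard Hub methods for free, e.g.
    # model.save_pretrained(...) and HF_LMModel.from_pretrained(...).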