import torch
import torch.nn.functional as F

from transformers import PreTrainedModel
from .original import TransformerModel, LMHead
'''
Code for HuggingFace Hub Compatibility
'''

class HF_LMModel(PreTrainedModel):
    """Transformer with a language-model head only."""
    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        if self.return_probs or self.return_acts:
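            # Mask the last n_ctx vocabulary slots (the position-embedding rows
            # of the shared embedding matrix): adding -1e12 to those logits
            # drives their probability to ~0 after the softmax in forward().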
            pos_emb_mask = torch.zeros(1, 1, config.n_vocab)
            pos_emb_mask[:, :, -config.n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
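        # Depending on the config, return softmax probabilities, masked raw
        # activations, or the unmodified logits.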
        if self.return_probs:
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            lm_logits = lm_logits + self.pos_emb_mask
        return { "logits": lm_logits }
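

# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept as a comment: this module uses a relative
# import, so it is not meant to be run as a script). It assumes the config is
# a PretrainedConfig that also carries whatever fields TransformerModel in
# .original reads (layer sizes, dropout rates, ...); the values below are
# hypothetical.
#
#     from transformers import PretrainedConfig
#
#     # PretrainedConfig stores unrecognised kwargs as plain attributes, so it
#     # can stand in for the repo's own config object in this sketch.
#     config = PretrainedConfig(
#         n_vocab=40990, n_ctx=512,       # hypothetical sizes
#         return_probs=False,             # emit raw logits
#         return_acts=True,               # add the position-slot mask to logits
#         # ... plus the fields TransformerModel itself expects ...
#     )
#     model = HF_LMModel(config)
#     model.save_pretrained("local-checkpoint")    # standard PreTrainedModel API
#     model = HF_LMModel.from_pretrained("local-checkpoint")
# ---------------------------------------------------------------------------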