XciD's picture
XciD HF staff
initial commit
8969f81
raw
history blame
1.9 kB
import torch
def forward(model_name, model, input_ids, past, device='cpu'):
    """Call ``model`` with the inputs expected by the family named in ``model_name``.

    The family is detected by substring match on ``model_name``, checked in
    this order: "gpt2"/"ctrl", "xlnet", "transfo-xl", then everything else.

    Args:
        model_name: Name used to pick the calling convention.
        model: Callable invoked with the prepared inputs.
        input_ids: 2-D tensor of token ids (indexed as ``[:, -1]`` and
            concatenated along ``dim=1`` below).
        past: Cached state from a previous call, or None. Forwarded as
            ``past=`` for gpt2/ctrl and as ``mems=`` for transfo-xl; unused
            by the other branches.
        device: Device on which the auxiliary xlnet tensors are allocated.

    Returns:
        Whatever ``model(...)`` returns.
    """
    is_gpt2_like = any(tag in model_name for tag in ("gpt2", "ctrl"))
    if is_gpt2_like:
        if past is None:
            return model(input_ids)
        # A cache is available: feed only the last column of input_ids.
        return model(input_ids[:, -1], past=past)

    if "xlnet" in model_name:
        batch_size = input_ids.shape[0]

        # Append one zero-valued placeholder id per row; it is the position
        # selected by target_mapping below.
        placeholder = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
        padded_ids = torch.cat((input_ids, placeholder), dim=1)
        seq_len = padded_ids.shape[1]

        # Last column set to 1 for every row -- presumably XLNet's convention
        # for hiding the appended position from all tokens; confirm against
        # the model documentation.
        perm_mask = torch.zeros(
            (batch_size, seq_len, seq_len), dtype=torch.float, device=device
        )
        perm_mask[:, :, -1] = 1.0

        # A single target per row, pointing at the appended last position.
        target_mapping = torch.zeros(
            (batch_size, 1, seq_len), dtype=torch.float, device=device
        )
        target_mapping[:, 0, -1] = 1.0

        return model(padded_ids, perm_mask=perm_mask, target_mapping=target_mapping)

    if "transfo-xl" in model_name:
        return model(input_ids, mems=past)

    return model(input_ids)
def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
    """Tokenize the prompt for ``model_name`` and look up its special token ids.

    Args:
        model_name: Name used to pick model-specific handling via substring
            match ("gpt2", "xlnet", "transfo-xl").
        tokenizer: Object providing ``encode(text)``; additionally an
            ``encoder`` mapping for gpt2 names and ``convert_tokens_to_ids``
            for xlnet names.
        initial_text: Prompt text. When empty and the model is gpt2, the
            literal "<|endoftext|>" is used instead.
        padding_text: Text prepended to the prompt for xlnet / transfo-xl
            names. None (the default) is treated as no padding.
        max_tokens: Keep only the last ``max_tokens`` encoded tokens. Halved
            for transfo-xl names.

    Returns:
        A tuple ``(context_tokens, eot_token, dot_token)``: the truncated
        token ids, the end-of-text id ("<|endoftext|>" for gpt2, "<eop>" for
        xlnet, otherwise None), and the last id produced by encoding ".".
    """
    is_gpt2 = "gpt2" in model_name
    is_xlnet = "xlnet" in model_name
    is_transfo_xl = "transfo-xl" in model_name

    # Treat a missing prompt like an empty one.
    initial_text = initial_text or ""
    if not initial_text and is_gpt2:
        initial_text = "<|endoftext|>"

    if is_xlnet or is_transfo_xl:
        # padding_text defaults to None; concatenating None with a str would
        # raise TypeError, so fall back to an empty prefix.
        initial_text = (padding_text or "") + initial_text

    if is_transfo_xl:
        max_tokens = int(max_tokens / 2)

    # NOTE: a max_tokens of 0 keeps the whole sequence, because [-0:] is a
    # full slice.
    context_tokens = tokenizer.encode(initial_text)[-max_tokens:]

    if is_gpt2:
        eot_token = tokenizer.encoder["<|endoftext|>"]
        if len(context_tokens) == 0:
            # Never hand the model an empty context.
            context_tokens = [eot_token]
    elif is_xlnet:
        eot_token = tokenizer.convert_tokens_to_ids('<eop>')
    else:
        eot_token = None

    dot_token = tokenizer.encode(".")[-1]
    return context_tokens, eot_token, dot_token