|
import torch

from transformers import GPT2Tokenizer, AutoModelForCausalLM

# Chat-format delimiters used to cut the assistant's reply out of the decoded
# generation: the reply starts after start_token and runs until the next
# occurrence of end_token (the prefix shared by "<|USER|>", "<|endoftext|>", ...).
start_token = "<|ASSISTANT|>"

end_token = "<|"

# Fine-tuned weights loaded on top of the base gpt2-large architecture.
CHECKPOINT_PATH = "/media/locutusque/T7/Projects/results/pytorch_model.bin"

# Choose the device up front so the model is moved to the same place the
# inputs go, and the script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

model = AutoModelForCausalLM.from_pretrained('gpt2-large', torch_dtype=torch.bfloat16)

# Tokenizer setup below presumably mirrors the setup used when the checkpoint
# was trained: the embedding matrix is resized to len(tokenizer) before the
# state dict is loaded, so the vocabulary size must match the saved weights.
tokenizer.pad_token = "[PAD]"

tokenizer.eos_token = "<|endoftext|>"

tokenizer.add_special_tokens({"additional_special_tokens": ["<|ASSISTANT|>", "<|USER|>", "<|SYSTEM|>"]})

model.resize_token_embeddings(len(tokenizer))

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# any machine; the model is moved to the selected device afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

model.to(device)
|
def generate_text(model, tokenizer, prompt, max_length=1024):
    """Sample a chat reply for ``prompt`` and return the full decoded sequence.

    The prompt is wrapped in the ``<|USER|> ... <|ASSISTANT|> `` template used
    by this script. The continuation is sampled with nucleus sampling
    (``top_p=0.1``, ``top_k=0`` i.e. no top-k cutoff), temperature 0.75 and a
    repetition penalty of 1.176.

    Args:
        model: causal language model exposing ``generate`` and ``device``
            (e.g. a Hugging Face ``PreTrainedModel``).
        tokenizer: tokenizer matching ``model``.
        prompt: raw user message.
        max_length: maximum total sequence length in tokens, counting the
            templated prompt as well as the generated tokens.

    Returns:
        The decoded text of the whole sequence (template included), with
        special tokens kept so the caller can locate the assistant's reply.
    """
    prompt = f'<|USER|> {prompt} <|ASSISTANT|> '

    # Send inputs to wherever the model's weights actually live instead of
    # relying on a module-level device that may disagree with the model.
    device = model.device

    input_ids = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt").to(device)

    # A single unpadded sequence: every position is a real token.
    # ones_like already allocates on input_ids' device.
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=0,
        top_p=0.1,
        temperature=0.75,
        repetition_penalty=1.176,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        attention_mask=attention_mask,
    )

    # Special tokens are kept on purpose: the reply is located by searching
    # for "<|ASSISTANT|>" and the following "<|" marker in this text.
    return tokenizer.decode(output[0], skip_special_tokens=False)
|
|
|
def _extract_reply(output_text):
    """Return the assistant's reply embedded in a decoded generation.

    The reply is the text after the first ``start_token`` up to, but not
    including, the next ``end_token``. If no ``end_token`` follows (e.g. the
    generation was cut off at ``max_length``), everything after
    ``start_token`` is returned. If ``start_token`` is absent, the whole text
    is treated as the reply.
    """
    head, found_start, tail = output_text.partition(start_token)
    body = tail if found_start else head
    # partition returns the full string when the separator is missing, so no
    # character is lost when end_token never appears.
    reply, _, _ = body.partition(end_token)
    return reply


while True:
    try:
        prompt = input("Enter a prompt (or 'q' to quit): ")
    except EOFError:
        # stdin closed (Ctrl-D / exhausted pipe): exit the same way as 'q'.
        break

    if prompt == "q":
        break

    output_text = generate_text(model, tokenizer, prompt)

    out = _extract_reply(output_text)

    print(out)
|
|