#!/usr/bin/env python3
"""Compare next-token predictions of a metaseq OPT checkpoint and a HF OPT model.

For each prompt, both models are run on the same token ids (the GPT-2 BPE ids
with id 0 prepended). The script prints each model's argmax next token and
whether the two full logit tensors agree within ``atol=1e-3``.
"""
import os

import torch
from transformers import AutoTokenizer, GPT2Tokenizer
from transformers import OPTForCausalLM

from metaseq import checkpoint_utils

path = "./model"
hf_path = "/home/patrick/facebook/opt-1.3b"

vocab_file = os.path.join(path, "gpt2-vocab.json")
merges_file = os.path.join(path, "gpt2-merges.txt")

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.save_pretrained(path)

checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "restored.pt")],
    arg_overrides={
        "vocab_filename": vocab_file,
        "merges_filename": merges_file,
    },
)
# First model of the loaded ensemble, switched to inference mode.
model = checkpoint[0][0].eval()

hf_model = OPTForCausalLM.from_pretrained(hf_path)


def _encode(prompts):
    """Tokenize *prompts* and prepend token id 0 to every row.

    Returns an integer tensor of token ids with one extra leading column.
    """
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # NOTE(review): id 0 is used as the start token for both models;
    # confirm it matches the BOS id each checkpoint expects.
    start = torch.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype)
    return torch.cat([start, input_ids], dim=-1)


def single_batch_forward_logits(prompts):
    """Return the metaseq model's logits for *prompts* (no gradients)."""
    input_ids = _encode(prompts)
    with torch.no_grad():
        logits = model(input_ids)[0]
    return logits


def forward_hf(prompts):
    """Return the Hugging Face OPT model's logits for *prompts* (no gradients)."""
    input_ids = _encode(prompts)
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
    return logits


def _predict_next_word(logits):
    """Decode the argmax token at the last position of the first sequence."""
    token_id = torch.argmax(logits[0, -1], dim=-1).item()
    token = tokenizer.convert_ids_to_tokens([token_id])[0]
    # GPT-2 byte-level BPE renders a leading space as "Ġ"; strip it for display.
    return token.replace("Ġ", "")


prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")

    logits_fsq = single_batch_forward_logits(prompt)
    print(f"Next word (metaseq): {_predict_next_word(logits_fsq)}")
    print("-------------")

    logits = forward_hf(prompt)
    print(f"Next word (HF): {_predict_next_word(logits)}")
    print("-------------")

    # Compares every position's logits, not just the final one.
    print("Is equal:", torch.allclose(logits_fsq.cpu(), logits.cpu(), atol=1e-3))