"""Compute perplexity on a sampled BabyLM perturbation test set using a trained Llama-3.2-3B checkpoint.

Usage: python <this script> [perturbation] [train_set] [batch_size] [seed]
"""

import argparse
import os
import sys

import torch

sys.path.append("..")

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from numpy.random import default_rng

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_NAME_SAVE = "Llama-3.2-3B"
FILE_SAMPLE_SIZE = 5


def get_perplexities(model, tokenizer, eval_dataset, batch_size):
    """Evaluate `model` on `eval_dataset` with the Trainer API and return perplexity (exp of the mean eval loss)."""
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    training_args = TrainingArguments(
        output_dir="./tmp_trainer",
        per_device_eval_batch_size=batch_size,
        fp16=torch.cuda.is_available(),  # fp16 is only valid on GPU; fall back to fp32 on CPU
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    eval_results = trainer.evaluate()
    print("eval_results:", eval_results)
    loss = eval_results["eval_loss"]
    perplexity = torch.exp(torch.tensor(loss)).item()
    return perplexity


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate perplexity on the test dataset.")
    parser.add_argument("perturbation", type=str, default="reverse_full", nargs="?", help="Type of perturbation to use.")
    parser.add_argument("train_set", type=str, default="test", nargs="?", help="Dataset size for training.")
    parser.add_argument("batch_size", type=int, default=4, nargs="?", help="Batch size for evaluation.")
    parser.add_argument("seed", type=int, default=0, nargs="?", help="Random seed.")
    args = parser.parse_args()

    # Load the test split of the requested BabyLM perturbation dataset.
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset("../train/babylm_dataset_test.py", name=dataset_name, trust_remote_code=True)
    test_dataset = dataset["test"]
    print(test_dataset)

    checkpoint_path = f"../train/checkpoints/{MODEL_NAME_SAVE}/babylm_{args.perturbation}_10M_seed0/runs/checkpoint-450"

    # Subsample FILE_SAMPLE_SIZE test examples with a seeded RNG for reproducibility.
    rng = default_rng(args.seed)
    indices = rng.choice(len(test_dataset), FILE_SAMPLE_SIZE, replace=False)
    sampled_test_dataset = test_dataset.select(indices)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS so max-length padding works.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
    model.eval()
    if torch.cuda.is_available():
        model.to("cuda")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=1024)

    tokenized_test = sampled_test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    perplexity = get_perplexities(model, tokenizer, tokenized_test, args.batch_size)
    print(f"Perplexity on test set: {perplexity}")