import sys

# Make utils_qwen (in the parent directory) importable.
sys.path.append("..")

from utils_qwen import CHECKPOINT_READ_PATH, PERTURBATIONS, BABYLM_DATA_PATH, \
    PAREN_MODELS
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from tqdm import tqdm
from glob import glob
from numpy.random import default_rng
from safetensors import safe_open
import pandas as pd
import torch
import itertools
import argparse
import os
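

# This script samples token sequences from the perturbed BabyLM test sets and
# computes per-sentence perplexities with a fine-tuned Qwen2.5-0.5B checkpoint.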
FILE_SAMPLE_SIZE = 1000  # token sequences sampled per test file
BATCH_SIZE = 8
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "Qwen/Qwen2.5-0.5B"
MODEL_NAME_SAVE = "Qwen2.5-0.5B"

checkpoint_path = 'checkpoint-2000'
checkpoint_dir = f'../train/checkpoints/babylm/babylm_shuffle_nondeterministic_10M_seed0/runs/{checkpoint_path}'

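
# Build a 0/1 attention mask marking real tokens (1) and right padding (0)
# for a batch of variable-length token ID lists.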
def create_attention_mask(token_lists):
    seq_length = max(len(tokens) for tokens in token_lists)
    batch_size = len(token_lists)
    mask = torch.zeros(batch_size, seq_length, dtype=torch.long)

    for i, tokens in enumerate(token_lists):
        mask[i, :len(tokens)] = 1

    return mask

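
# Right-pad each token list with pad_token_id and stack the batch into a
# (batch_size, max_seq_length) tensor.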
def create_input_ids(token_lists, pad_token_id):
    padded = zip(*itertools.zip_longest(*token_lists, fillvalue=pad_token_id))
    return torch.tensor(list(padded))

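
# Compute one perplexity per sequence: token-level cross-entropy is masked at
# padding positions, averaged per sequence, and exponentiated.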
def get_perplexities(model, token_lists, pad_token_id, device="cuda"):
    input_ids = create_input_ids(token_lists, pad_token_id).to(device)
    labels = input_ids.clone()
    attention_mask = create_attention_mask(token_lists).to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels,
                        attention_mask=attention_mask)

    # Shift so that each position is scored against the next token.
    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_attention_mask = attention_mask[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))

    # Zero out padding positions before averaging per sequence.
    loss = loss.view(shift_labels.size())
    loss = loss * shift_attention_mask
    per_example_loss = loss.sum(dim=1) / shift_attention_mask.sum(dim=1)
    return torch.exp(per_example_loss).tolist()

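
# Debugging helper: check that two models have identical parameter tensors
# (not called in the main script below).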
def models_are_equal(model1, model2):
    if type(model1) is not type(model2):
        return False

    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        if not torch.equal(param1.data, param2.data):
            return False

    return True

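
# Debugging helper: forward hook that prints a module's output when LoRA
# adapters are attached (not registered in the main script below).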
def print_lora_output(module, input, output):
    print(f"{module.__class__.__name__} output with LoRA: {output}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Perplexity evaluation',
        description='Compute perplexities on perturbed BabyLM test data')

    parser.add_argument('test_perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform the BabyLM test set')
    parser.add_argument('random_seed', type=int, help="Random seed")

    args = parser.parse_args()

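    # Gather the perturbed test files for the requested perturbation type and
    # sample up to FILE_SAMPLE_SIZE token sequences from each file.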
    test_files = sorted(glob(
        f"../data/babylm_data_perturbed_qwen/babylm_{args.test_perturbation_type}/babylm_test_affected/*"))

    rng = default_rng(args.random_seed)

    print("Sampling BabyLM affected test files to extract surprisals...")
    token_sequences = []
    print("test_files:", test_files)
    for test_file in test_files:
        print(test_file)
        with open(test_file, 'r') as f:
            file_token_sequences = [
                [int(s) for s in line.split()] for line in f.readlines()]
        # Guard against files with fewer than FILE_SAMPLE_SIZE sequences.
        sample_size = min(FILE_SAMPLE_SIZE, len(file_token_sequences))
        sample_indices = rng.choice(
            len(file_token_sequences), sample_size, replace=False)
        file_token_sequences = [file_token_sequences[i]
                                for i in sample_indices]
        token_sequences.extend(file_token_sequences)

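    # Load the fine-tuned checkpoint and its tokenizer.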
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

    # Decoded sentences are kept alongside perplexities for the output CSV.
    test_sents = [tokenizer.decode(toks) for toks in token_sequences]

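    # Score the sampled sequences in batches, padding with the EOS token.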
    perplexities = []
    for i in tqdm(range(0, len(token_sequences), BATCH_SIZE)):
        batch = token_sequences[i:i + BATCH_SIZE]
        ppls = get_perplexities(
            model, batch, tokenizer.eos_token_id, device=device)
        perplexities.extend(ppls)

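    # Assemble results and write them to a per-model, per-perturbation CSV.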
    ppl_df = pd.DataFrame({
        "Sentences": test_sents,
        "Perplexities": perplexities
    })

    directory = f"perplexity_results/{MODEL_NAME_SAVE}/{args.test_perturbation_type}"
    os.makedirs(directory, exist_ok=True)
    print("directory:", directory)

    file = f"{directory}/{MODEL_NAME_SAVE}_seed{args.random_seed}_test_{args.test_perturbation_type}{checkpoint_path}.csv"
    print(f"Writing results to CSV: {file}")
    ppl_df.to_csv(file, index=False)