# edge_probing.py
# Author: Julie Kallini
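#
# Trains logistic-regression probes on pooled span embeddings from GPT-2
# checkpoints to predict constituent categories (edge probing).
#
# Example invocation (argument values are placeholders; valid choices come
# from utils.PERTURBATIONS and utils.PAREN_MODELS):
#   python edge_probing.py <perturbation_type> <train_set> <random_seed> \
#       <paren_model> <pooling_operation> [--no_pos_encodings]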
# For importing utils
import sys
sys.path.append("..")
from utils import CHECKPOINT_READ_PATH, PERTURBATIONS, PAREN_MODELS, get_gpt2_tokenizer_with_markers
from gpt2_no_positional_encoding_model import GPT2NoPositionalEncodingLMHeadModel
from transformers import GPT2LMHeadModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from itertools import zip_longest
import torch
import tqdm
import argparse
import pandas as pd
import os
MAX_TRAINING_STEPS = 3000
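# Training checkpoints to probe (every 200 steps up to MAX_TRAINING_STEPS)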
CHECKPOINTS = list(range(200, MAX_TRAINING_STEPS+1, 200))
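# GPT-2 layers to probe; "Avg Last 4" averages the last four hidden layers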
LAYERS = [1, 3, 6, 9, 12, "Avg Last 4"]


def get_layer_embedding(model, token_sequences, indices, layer=None):
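    """
    Run the model on a padded batch of token sequences and return hidden states.

    If `layer` is an int, the hidden states of that layer are returned; if
    `layer` is None, the average of the last four hidden layers is used.
    Positions outside each (start, end) span in `indices` are zeroed out.
    """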
# Pad input token sequences
input_ids = zip(*zip_longest(*token_sequences,
fillvalue=gpt2_tokenizer.eos_token_id))
input_ids = torch.tensor(list(input_ids)).to(device)
# Get GPT2 model's output
with torch.no_grad():
output = model(input_ids)
# Either get the hidden state of the specified layer or
# get the average of the last 4 hidden states
if layer is not None:
hidden_states = output.hidden_states[layer]
else:
hidden_states = output.hidden_states[-4:]
hidden_states = sum(hidden_states) / 4
# Create mask using start and end indices
batch_size, seq_length = input_ids.shape
mask = torch.full((batch_size, seq_length), 0).to(device)
for i, (start_idx, end_idx) in enumerate(indices):
mask[i, start_idx:end_idx] = 1
# Mask out embeddings of tokens outside indices
mask_expanded = mask.unsqueeze(-1).expand(hidden_states.size())
hidden_states = hidden_states * mask_expanded
return hidden_states


def max_pooling(tensor, index_tuples):
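    """Max-pool each span tensor[i, start:end, :] into a single embedding."""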
pooled_results = []
for i, (start, end) in enumerate(index_tuples):
# Extracting the embeddings corresponding to the specified range
embeddings = tensor[i, start:end, :]
# Performing max pooling
max_pooled = torch.max(embeddings, dim=0)[0]
pooled_results.append(max_pooled)
return torch.stack(pooled_results)


def mean_pooling(tensor, index_tuples):
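    """Mean-pool each span tensor[i, start:end, :] into a single embedding."""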
batch_size, seq_len, embedding_size = tensor.shape
output = torch.empty(batch_size, embedding_size,
device=tensor.device, dtype=tensor.dtype)
for i, (start, end) in enumerate(index_tuples):
embeddings = tensor[i, start:end, :]
output[i, :] = torch.mean(embeddings, dim=0)
return output


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='Edge probing',
description='Edge probing experiments')
parser.add_argument('perturbation_type',
default='all',
const='all',
nargs='?',
choices=PERTURBATIONS.keys(),
help='Perturbation function used to transform BabyLM dataset')
parser.add_argument('train_set',
default='all',
const='all',
nargs='?',
choices=["100M", "10M"],
help='BabyLM train set')
parser.add_argument('random_seed', type=int, help="Random seed")
parser.add_argument('paren_model',
default='all',
const='all',
nargs='?',
choices=list(PAREN_MODELS.keys()) + ["randinit"],
help='Parenthesis model')
parser.add_argument('pooling_operation',
default='all',
const='all',
nargs='?',
choices=["mean", "max"],
help='Pooling operation to compute on embeddings')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Probe GPT-2 models trained without positional encodings")
# Get args
args = parser.parse_args()
if args.pooling_operation == "mean":
pooling_function = mean_pooling
elif args.pooling_operation == "max":
pooling_function = max_pooling
else:
raise Exception("Pooling operation undefined")
# Init tokenizer
gpt2_tokenizer = get_gpt2_tokenizer_with_markers([])
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
# Get path to model
no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""
model = f"babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}_seed{args.random_seed}"
model_path = f"{CHECKPOINT_READ_PATH}/babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}/{model}/runs/{model}/checkpoint-"
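    # Each probed checkpoint is loaded from model_path + <step> (".../checkpoint-<step>")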
# Get constituency parse data
if "hop" in args.perturbation_type:
phrase_df = pd.read_csv("phrase_data/hop_phrase_data.csv")
elif "reverse" in args.perturbation_type:
phrase_df = pd.read_csv("phrase_data/reverse_phrase_data.csv")
else:
raise Exception("Phrase data not found")
token_sequences = list(phrase_df["Sentence Tokens"])
if args.perturbation_type == "reverse_full":
indices = list(
zip(phrase_df["Rev Start Index"], phrase_df["Rev End Index"]))
else:
indices = list(zip(phrase_df["Start Index"], phrase_df["End Index"]))
labels = list(phrase_df["Category"])
BATCH_SIZE = 32
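    # Embedding extraction runs on GPU; a CUDA-capable device is assumed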
device = "cuda"
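    # One row per probed layer; an accuracy column is appended for each checkpoint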
edge_probing_df = pd.DataFrame(LAYERS, columns=["GPT-2 Layer"])
for ckpt in CHECKPOINTS:
# Load model
        if args.no_pos_encodings:
            model = GPT2NoPositionalEncodingLMHeadModel.from_pretrained(
                model_path + str(ckpt), output_hidden_states=True).to(device)
        else:
            model = GPT2LMHeadModel.from_pretrained(
                model_path + str(ckpt), output_hidden_states=True).to(device)
layer_accuracies = []
for layer in LAYERS:
print(f"Checkpoint: {ckpt}, Layer: {layer}")
print("Computing span embeddings...")
# Iterate over token sequences and indices to get embeddings
spans = []
for i in tqdm.tqdm(list(range(0, len(token_sequences), BATCH_SIZE))):
tokens_batch = [[int(tok) for tok in seq.split()]
for seq in token_sequences[i:i+BATCH_SIZE]]
if args.perturbation_type == "reverse_full":
tokens_batch = [toks[::-1] for toks in tokens_batch]
index_batch = indices[i:i+BATCH_SIZE]
# Extract embeddings
if layer == "Avg Last 4":
embeddings = get_layer_embedding(
model, tokens_batch, index_batch, None)
else:
embeddings = get_layer_embedding(
model, tokens_batch, index_batch, layer)
pooled_results = pooling_function(embeddings, index_batch)
spans.extend(list(pooled_results))
# Get features and ground truth
X = torch.vstack(spans).detach().cpu().numpy()
y = labels
# Split the data; since we pass random seed, it
# will be the same split every time
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=args.random_seed)
# Fit L2-regularized linear classifier
            clf = LogisticRegression(max_iter=10**4,  # enough iterations for the solver to converge
                                     random_state=args.random_seed).fit(X_train, y_train)
# Get probe accuracy
y_test_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_test_pred)
layer_accuracies.append(acc)
print(f"Accuracy: {acc}")
edge_probing_df[f"Accuracy (ckpt {ckpt})"] = layer_accuracies
# Write results to CSV
nps = '_no_pos_encodings' if args.no_pos_encodings else ''
directory = f"edge_probing_results/{args.perturbation_type}_{args.train_set}{nps}"
if not os.path.exists(directory):
os.makedirs(directory)
file = directory + \
f"/{args.paren_model}_{args.pooling_operation}_pooling_seed{args.random_seed}.csv"
print(f"Writing results to CSV: {file}")
edge_probing_df.to_csv(file)