|
|
|
|
|
|
|
|
|
import sys |
|
sys.path.append("..") |
|
|
|
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \ |
|
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file |
|
from glob import glob |
|
import numpy as np |
|
import itertools |
|
import json |
|
import os |
|
import tqdm |
|
import argparse |
|
import pytest |
|
|
|
# Model identifier; interpolated into the output directory path built by the
# ``__main__`` script in this file.
MODEL_NAME = "Llama-3.2-3B"
|
|
|
def lines_equivalent_3pres(file1_path, file2_path):
    """Check that two 3pres files match line by line once markers are removed.

    Each line is split on whitespace and every token is parsed with ``int``;
    tokens equal to ``marker_sg_token`` or ``marker_pl_token`` are dropped
    before the remaining token lists are compared.

    Parameters:
    - file1_path (str): Path to the first file.
    - file2_path (str): Path to the second file.

    Returns:
    - bool: True when both files have the same number of lines and every
      pair of lines is equal after marker removal, False otherwise. The first
      mismatching pair of raw lines is printed before returning False.

    Raises:
    - ValueError: If a whitespace-separated token cannot be parsed by ``int``.
    """
    markers = (marker_sg_token, marker_pl_token)

    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        # zip_longest (rather than zip) pads the shorter file with None, so a
        # length mismatch is detected regardless of which file is longer.
        # Plain zip silently consumes one extra line from the first file.
        for line1, line2 in itertools.zip_longest(file1, file2):
            if line1 is None or line2 is None:
                return False

            res1 = [tok for tok in line1.split() if int(tok) not in markers]
            res2 = [tok for tok in line2.split() if int(tok) not in markers]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

    return True
|
|
|
|
|
# Pairs of 3pres perturbation variants whose output files are compared
# against each other by test_3pres_all_equivalent.
perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Every (split, genre, perturbation_pair) combination for the 3pres test.
# Materialised as a list because itertools.product returns a one-shot
# iterator that is empty after its first full traversal.
test_data = list(itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_3pres,
))
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_3pres_all_equivalent(split, genre, perturbation_pair):
    """Assert that two 3pres perturbation variants produced equivalent files.

    Parameters:
    - split (str): Dataset split name; selects the file suffix and the
      ``babylm_{split}`` sub-directory.
    - genre (str): Genre name used as the file stem.
    - perturbation_pair (tuple): The two 3pres variant names to compare.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        # Without this branch an unexpected split surfaces as an
        # UnboundLocalError on `filename`, which hides the real cause.
        raise ValueError(f"Unknown BabyLM split: {split!r}")

    # NOTE(review): the __main__ writer in this file emits under
    # "Perturbed_data/{MODEL_NAME}"; confirm this read location is intended.
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"
|
|
|
|
|
def lines_equivalent_reversal(rev_path, ident_path):
    """Check that a reversal file is the identity file with reversed tails.

    For each pair of lines, the position of ``str(marker_rev_token)`` is
    located in the reversal line. The tokens up to and including that
    position must be identical in both lines, and the tokens after it in the
    reversal line must equal the identity line's remaining tokens when read
    in reverse order.

    Parameters:
    - rev_path (str): Path to the reversal file.
    - ident_path (str): Path to the identity file.

    Returns:
    - bool: True when both files have the same number of lines and every
      pair of lines satisfies the relation above, False otherwise.

    Raises:
    - ValueError: If a line of the reversal file does not contain the
      marker token (raised by ``list.index``).
    """
    marker = str(marker_rev_token)

    with open(rev_path, 'r') as rev_file, open(ident_path, 'r') as ident_file:
        # zip_longest (rather than zip) pads the shorter file with None, so a
        # length mismatch is detected regardless of which file is longer.
        # Plain zip silently consumes one extra line from the first file.
        for rev_line, ident_line in itertools.zip_longest(rev_file, ident_file):
            if rev_line is None or ident_line is None:
                return False

            rev_tokens = rev_line.split()
            ident_tokens = ident_line.split()

            # Index just past the marker: the prefix [:tail_start] includes it.
            tail_start = rev_tokens.index(marker) + 1

            # Tokens up to and including the marker must be identical.
            if rev_tokens[:tail_start] != ident_tokens[:tail_start]:
                return False

            # The rest of the reversal line, reversed, must equal the rest
            # of the identity line.
            if rev_tokens[tail_start:][::-1] != ident_tokens[tail_start:]:
                return False

    return True
|
|
|
|
|
# (reversal, identity) perturbation names compared by
# test_reversal_all_equivalent.
perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]

# Every (split, genre, perturbation_pair) combination for the reversal test.
# Materialised as a list because itertools.product returns a one-shot
# iterator that is empty after its first full traversal.
test_data = list(itertools.product(
    ["100M", "dev", "test_affected"],
    GENRES.keys(),
    perturbation_pairs_reversal,
))
|
|
|
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
    """Assert that the reversal file matches its identity counterpart.

    Parameters:
    - split (str): Dataset split name; selects the file suffix and the
      ``babylm_{split}`` sub-directory.
    - genre (str): Genre name used as the file stem.
    - perturbation_pair (tuple): (reversal, identity) perturbation names.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        # Without this branch an unexpected split surfaces as an
        # UnboundLocalError on `filename`, which hides the real cause.
        raise ValueError(f"Unknown BabyLM split: {split!r}")

    # NOTE(review): the __main__ writer in this file emits under
    # "Perturbed_data/{MODEL_NAME}"; confirm this read location is intended.
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"
|
|
|
|
|
def lines_equivalent_determiner_swap(det_path, ident_path):
    """Check that two files contain the same set of tokens on every line.

    Each line is split on whitespace and the resulting tokens are compared
    as sets, so token order and repetition counts within a line are ignored.

    Parameters:
    - det_path (str): Path to the determiner-swap file.
    - ident_path (str): Path to the identity file.

    Returns:
    - bool: True when both files have the same number of lines and every
      pair of lines has equal token sets, False otherwise. The token lists
      of the first mismatching pair are printed before returning False.
    """
    with open(det_path, 'r') as det_file, open(ident_path, 'r') as ident_file:
        # zip_longest (rather than zip) pads the shorter file with None, so a
        # length mismatch is detected regardless of which file is longer.
        # Plain zip silently consumes one extra line from the first file.
        for det_line, ident_line in itertools.zip_longest(det_file, ident_file):
            if det_line is None or ident_line is None:
                return False

            det_tokens = det_line.split()
            ident_tokens = ident_line.split()
            if set(det_tokens) != set(ident_tokens):
                print(det_tokens)
                print(ident_tokens)
                return False

    return True
|
|
|
|
|
# (determiner swap, identity) perturbation names compared by
# test_determiner_swap_all_equivalent. Kept under its own name so the
# reversal pairs list defined earlier in this file is not overwritten.
perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]

# Every (split, genre, perturbation_pair) combination for the determiner-swap
# test. Materialised as a list because itertools.product returns a one-shot
# iterator that is empty after its first full traversal.
test_data = list(itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"],
    GENRES.keys(),
    perturbation_pairs_determiner_swap,
))
|
|
|
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
    """Assert that the determiner-swap file matches its identity counterpart.

    Parameters:
    - split (str): Dataset split name; selects the file suffix and the
      ``babylm_{split}`` sub-directory.
    - genre (str): Genre name used as the file stem.
    - perturbation_pair (tuple): (determiner swap, identity) perturbation
      names.
    """
    perturbation1, perturbation2 = perturbation_pair

    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        # Without this branch an unexpected split surfaces as an
        # UnboundLocalError on `filename`, which hides the real cause.
        raise ValueError(f"Unknown BabyLM split: {split!r}")

    # NOTE(review): the __main__ writer in this file emits under
    # "Perturbed_data/{MODEL_NAME}"; confirm this read location is intended.
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_llama/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"
|
|
|
|
|
def flatten_list(l):
    """Return a new list with the items of every element of ``l``, in order.

    Only one level of nesting is removed.
    """
    return [item for inner in l for item in inner]
|
|
|
|
|
def process_line(line):
    """
    Perturb every annotated sentence of one dataset entry and sort the
    resulting token lines into affected and unaffected groups.

    A sentence is skipped when its perturbed token sequence has at most one
    token outside ``MARKER_TOKEN_IDS``. Affected sentences (per
    ``affect_function``) are kept only when ``filter_function`` accepts them;
    all other sentences are recorded as unaffected together with their
    original text.

    Parameters:
    - line (dict): A dataset entry; its "sent_annotations" value is iterated
        sentence by sentence, and "sent_text" is read from each unaffected
        sentence.

    Returns:
    - tuple: A tuple containing three lists:
        1. Token lines (space-separated, newline-terminated) of affected
           sentences that passed the filter.
        2. Token lines of unaffected sentences.
        3. The "sent_text" of each unaffected sentence, newline-terminated.

    Note:
    - `perturbation_function`, `affect_function` and `filter_function` are
        read from the global scope.
    """
    affected_lines = []
    unaffected_lines = []
    unaffected_texts = []

    for sent in line["sent_annotations"]:
        tokens = perturbation_function(sent)

        # Skip sentences left with at most one non-marker token.
        non_marker_count = sum(
            1 for tok in tokens if tok not in MARKER_TOKEN_IDS)
        if non_marker_count <= 1:
            continue

        token_line = " ".join(map(str, tokens))

        if not affect_function(sent):
            unaffected_lines.append(token_line + "\n")
            unaffected_texts.append(sent["sent_text"] + "\n")
        elif filter_function(sent):
            # Affected sentences rejected by the filter are dropped entirely.
            affected_lines.append(token_line + "\n")

    return affected_lines, unaffected_lines, unaffected_texts
|
|
|
|
|
def _perturb_file(json_path):
    """Load one parsed-JSON file and run ``process_line`` over every entry.

    Returns three flat lists: affected token lines, unaffected token lines
    and unaffected sentence texts. An empty JSON list yields three empty
    lists.
    """
    # Context manager so the handle is closed even if json.load raises.
    with open(json_path) as f:
        data = json.load(f)

    affected, unaffected, unaffected_sents = [], [], []
    for line in tqdm.tqdm(data):
        line_affected, line_unaffected, line_sents = process_line(line)
        affected.extend(line_affected)
        unaffected.extend(line_unaffected)
        unaffected_sents.extend(line_sents)
    return affected, unaffected, unaffected_sents


if __name__ == "__main__":

    # NOTE(review): with nargs='?' the default 'all' is what argparse yields
    # when an argument is omitted; confirm 'all' is a valid key of
    # PERTURBATIONS / member of BABYLM_SPLITS.
    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    args = parser.parse_args()

    # Collect the input files for the chosen split.
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # These module-level names are read as globals by process_line.
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    llama_tokenizer = PERTURBATIONS[args.perturbation_type]['llama_tokenizer']

    data_write_directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}"

    for file in babylm_data:
        print(file)
        new_lines_affected, new_lines_unaffected, unaffected_sents = \
            _perturb_file(file)
        base_name = os.path.basename(file)

        if babylm_dataset == "test":
            # The test split keeps affected lines, unaffected lines and the
            # unaffected sentence texts in three separate directories.
            outputs = [
                (f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/",
                 base_name.replace(json_ext, "_affected.test"),
                 new_lines_affected),
                (f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/",
                 base_name.replace(json_ext, "_unaffected.test"),
                 new_lines_unaffected),
                (f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/",
                 base_name.replace(json_ext, "_unaffected_sents.test"),
                 unaffected_sents),
            ]
        else:
            # Other splits write one file: unaffected lines, then affected.
            new_lines = new_lines_unaffected + new_lines_affected

            if babylm_dataset == "dev":
                new_file = base_name.replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = base_name.replace(json_ext, ".test")

                # For unittest output, follow each token line with its
                # decoded string.
                new_lines_with_strings = []
                for token_line in new_lines:
                    decoded = llama_tokenizer.decode(
                        [int(tok) for tok in token_line.split()]) + "\n"
                    new_lines_with_strings.append(token_line)
                    new_lines_with_strings.append(decoded)
                new_lines = new_lines_with_strings
            else:
                new_file = base_name.replace(json_ext, ".train")

            directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            print("directory:", directory)
            outputs = [(directory, new_file, new_lines)]

        for out_directory, out_file, out_lines in outputs:
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs(out_directory, exist_ok=True)
            write_file(out_directory, out_file, out_lines)