|
|
|
|
|
|
|
|
|
import sys |
|
sys.path.append("..") |
|
|
|
import pandas as pd |
|
from glob import glob |
|
import re |
|
from utils import gpt2_hop_tokenizer, BABYLM_DATA_PATH, marker_sg_token |
|
from pluralizer import Pluralizer |
|
|
|
|
|
def build_patterns(vocab, marker):
    """Compile the regexes that locate candidate sequences in token-id text.

    Both patterns start (at line start or after a space) with the id of
    ``The`` or ``Ġthe`` followed by six further ids.  They differ only in where
    the ``marker`` id sits: after the second following id in the control
    pattern, and after the sixth following id in the words4 pattern.

    Args:
        vocab: Mapping from token string to token id; must contain the keys
            ``'The'`` and ``'Ġthe'``.
        marker: Id of the marker token as it appears in the data files.

    Returns:
        A ``(control_regex, words4_regex)`` tuple of compiled patterns.
    """
    # NOTE(review): ``[1-9]+`` cannot match ids containing the digit 0, so any
    # sequence with such an id is silently skipped.  Kept as-is because
    # widening it to ``[0-9]+`` would change the generated dataset -- confirm
    # whether the exclusion is intended.
    control_pattern = f"(?:^| )(?:{vocab['The']}|{vocab['Ġthe']}) [1-9]+ [1-9]+ {marker} [1-9]+ [1-9]+ [1-9]+ [1-9]+"
    words4_pattern = f"(?:^| )(?:{vocab['The']}|{vocab['Ġthe']}) [1-9]+ [1-9]+ [1-9]+ [1-9]+ [1-9]+ [1-9]+ {marker}"
    return re.compile(control_pattern), re.compile(words4_pattern)


def find_shared_sequences(control_line, words4_line, control_re, words4_re,
                          marker):
    """Return the sequences that occur in both lines once the marker is removed.

    The i-th control match is paired with the i-th words4 match.  A pair is
    kept only when the two matched strings are identical after dropping
    ``" " + str(marker)`` from each.

    Args:
        control_line: One line of space-separated token ids (control file).
        words4_line: The corresponding line from the words4 file.
        control_re: Compiled control pattern from ``build_patterns``.
        words4_re: Compiled words4 pattern from ``build_patterns``.
        marker: Marker token id to strip from the matches.

    Returns:
        A list of token-id lists (marker removed), one per kept pair.
    """
    marker_text = str(marker)
    spaced_marker = " " + marker_text
    shared = []
    # The patterns only use non-capturing groups, so findall() yields the
    # full matched text for each occurrence.
    for control_match, words4_match in zip(control_re.findall(control_line),
                                           words4_re.findall(words4_line)):
        if (control_match.replace(spaced_marker, "")
                == words4_match.replace(spaced_marker, "")):
            # split() also discards the leading space a mid-line match carries.
            shared.append(
                [int(tok)
                 for tok in words4_match.replace(marker_text, "").split()]
            )
    return shared


def pluralize_sequence(seq, tokenizer, pluralizer):
    """Build the plural counterpart of a token-id sequence.

    The sequence is decoded, its second whitespace-separated word is
    pluralized, and the text is re-encoded.

    Args:
        seq: List of token ids.
        tokenizer: Object providing ``decode(ids)`` and ``encode(text)``.
        pluralizer: Object providing ``pluralize(word, count, inclusive)``.

    Returns:
        The re-encoded list of ids, or ``None`` when the pair is unusable:
        fewer than two decoded words, a different token count after
        pluralizing, or no change at all.
    """
    words = tokenizer.decode(seq).split()
    if len(words) < 2:
        # Nothing to pluralize; the original script would raise IndexError.
        return None

    # count=2 asks for the plural form; the trailing False is presumably the
    # library's "inclusive" flag (no count prefix) -- verify against pluralizer.
    words[1] = pluralizer.pluralize(words[1], 2, False)
    plur_seq = tokenizer.encode(" ".join(words))

    # Only keep pairs whose token counts line up position by position.
    if len(plur_seq) != len(seq):
        return None

    # Re-encoding the joined text may yield a different id for the first
    # word, so the original first id is restored.
    plur_seq[0] = seq[0]

    # Pluralizing left the sequence unchanged: no contrast to record.
    if seq == plur_seq:
        return None
    return plur_seq


def main():
    """Collect singular/plural sequence pairs and write them to a CSV file.

    Reads the control and words4 test files pairwise, gathers sequences that
    match in both, pluralizes each one, and writes the usable pairs to
    ``hop_agreement_data.csv`` with the columns ``Singular`` and ``Plural``.
    """
    vocab = gpt2_hop_tokenizer.vocab
    control_re, words4_re = build_patterns(vocab, marker_sg_token)

    # Glob template; the ``{}`` slot receives the perturbation name.
    test_file_path = f'{BABYLM_DATA_PATH}/babylm_data_perturbed/' + \
        'babylm_hop_{}/babylm_test_affected/*'
    control_files = sorted(glob(test_file_path.format("control")))
    words4_files = sorted(glob(test_file_path.format("words4")))

    candidate_sequences = []
    for control_file_path, words4_file_path in zip(control_files,
                                                   words4_files):
        file_name = control_file_path.split("/")[-1]
        print(file_name)
        # Sorted globs are paired by position; make sure they really are the
        # same file in both directories.
        assert file_name == words4_file_path.split("/")[-1]

        # Context managers close both files even if processing fails.
        with open(control_file_path, 'r') as control_file, \
                open(words4_file_path, 'r') as words4_file:
            for control_line, words4_line in zip(control_file, words4_file):
                candidate_sequences.extend(
                    find_shared_sequences(control_line, words4_line,
                                          control_re, words4_re,
                                          marker_sg_token)
                )

    pluralizer = Pluralizer()

    data = []
    for seq in candidate_sequences:
        plur_seq = pluralize_sequence(seq, gpt2_hop_tokenizer, pluralizer)
        if plur_seq is None:
            continue
        data.append([" ".join([str(s) for s in seq]),
                     " ".join([str(s) for s in plur_seq])])

    df = pd.DataFrame(data, columns=["Singular", "Plural"])
    df.to_csv("hop_agreement_data.csv")


if __name__ == "__main__":
    main()
|
|