# perturb.py
# Author: Julie Kallini
# For importing utils
import sys
sys.path.append("..")
from utils_qwen import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from glob import glob
import numpy as np
import itertools
import json
import os
import tqdm
import argparse
import pytest
MODEL_NAME = "Qwen2.5-7B"
def lines_equivalent_3pres(file1_path, file2_path):
"""Compare lines of two files after splitting them."""
with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
for line1, line2 in zip(file1, file2):
# Split each line and compare the resulting lists
res1 = [i for i in line1.split() if int(
i) not in (marker_sg_token, marker_pl_token)]
res2 = [i for i in line2.split() if int(
i) not in (marker_sg_token, marker_pl_token)]
if res1 != res2:
print(line1)
print(line2)
return False
# Check if one file has more lines than the other
if file1.readline() or file2.readline():
return False
return True
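# A minimal illustration of the equivalence check above (hypothetical token
# IDs; the real marker values live in utils_qwen): with marker_sg_token = 8
# and marker_pl_token = 9, the lines "1 8 2 3" and "1 2 9 3" compare equal,
# since both reduce to ["1", "2", "3"] once the marker tokens are dropped.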
perturbation_pairs_3pres = [
("0tokens", "4tokens"),
("0tokens", "4words"),
("4tokens", "4words"),
]
# Yj: combinatorial tests over the perturbation pairs related to third-person singular/plural marking
test_data = itertools.product(
["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_3pres) # Yj: generate different pairs used in test
# Yj: 用于在测试函数中,例如 test_3pres_all_equivalent,生成各种测试组合,包括不同的扰动策略。
# Yj: 区分受影响和未受影响的测试子集,以比较扰动前后的效果。
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)  # the test function runs once per parameter tuple in test_data
def test_3pres_all_equivalent(split, genre, perturbation_pair):  # Yj: genres are the different corpus domains; see utils_qwen.py
perturbation1, perturbation2 = perturbation_pair
if split in ("100M", "10M"):
filename = f"{genre}.train"
elif split == "test_affected":
filename = f"{genre}_affected.test"
elif split == "test_unaffected":
filename = f"{genre}_unaffected.test"
elif split == "dev":
filename = f"{genre}.dev" # Yj: Development Set is similar to Validation Set
path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"
    # Yj: compare the files at the two paths
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} has non-equivalent lines!"
def lines_equivalent_reversal(rev_path, ident_path):
"""Compare lines of reversal file and identity file after splitting them."""
with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
for line1, line2 in zip(file1, file2):
# Split each line and compare the resulting lists
line1_tokens = line1.split()
line2_tokens = line2.split()
# Get REV marker index
marker_index = line1_tokens.index(str(marker_rev_token))
# Make sure tokens up to and including the marker are all the same
if line1_tokens[:marker_index+1] != line2_tokens[:marker_index+1]:
return False
# Make sure reversal of rest of string is equal to identity
line1_tokens_rev = line1_tokens[marker_index+1:].copy()
line1_tokens_rev.reverse()
if line1_tokens_rev != line2_tokens[marker_index+1:]:
return False
# Check if one file has more lines than the other
if file1.readline() or file2.readline():
return False
return True
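# A minimal illustration of the reversal check above (hypothetical token IDs;
# the real marker value lives in utils_qwen): with marker_rev_token = 9, the
# reversal line "1 2 9 5 4 3" is equivalent to the identity line "1 2 9 3 4 5":
# the prefixes up to and including the marker match, and reversing the suffix
# of the first line yields the suffix of the second.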
perturbation_pairs_reversal = [
("reversal", "reversal_identity"),
]
# Yj: combinatorial tests for the reversal perturbation pair
test_data = itertools.product(
["100M", "dev", "test_affected"], GENRES.keys(), perturbation_pairs_reversal)
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
perturbation1, perturbation2 = perturbation_pair
if split in ("100M", "10M"):
filename = f"{genre}.train"
elif split == "test_affected":
filename = f"{genre}_affected.test"
elif split == "test_unaffected":
filename = f"{genre}_unaffected.test"
elif split == "dev":
filename = f"{genre}.dev"
path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"
    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} has non-equivalent lines!"
def lines_equivalent_determiner_swap(det_path, ident_path):
"""Compare lines of reversal file and identity file after splitting them."""
with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
for line1, line2 in zip(file1, file2):
            # Split each line and compare the resulting token sets (order-insensitive)
line1_tokens = set(line1.split())
line2_tokens = set(line2.split())
if line1_tokens != line2_tokens:
print(line1.split())
print(line2.split())
return False
# Check if one file has more lines than the other
if file1.readline() or file2.readline():
return False
return True
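# Note on the check above: comparing set()s ignores token order, which is what
# a determiner swap changes, but it also collapses repeated tokens, so two
# lines that differ only in how often a token appears would still compare equal.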
perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_determiner_swap)
@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
perturbation1, perturbation2 = perturbation_pair
if split in ("100M", "10M"):
filename = f"{genre}.train"
elif split == "test_affected":
filename = f"{genre}_affected.test"
elif split == "test_unaffected":
filename = f"{genre}_unaffected.test"
elif split == "dev":
filename = f"{genre}.dev"
path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"
    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} has non-equivalent lines!"
def flatten_list(l):
"""Function to flatten a nested list."""
return list(itertools.chain.from_iterable(l))
def process_line(line):
"""
Process a given line from the dataset, apply transformations to its sentences,
and categorize them into affected or unaffected based on the transformation.
Parameters:
- line (dict): A dictionary representing a line from the dataset, which contains
sentence annotations.
Returns:
    - tuple: A tuple containing three lists:
        1. new_lines_affected (list of str): token lines for sentences affected by the transformation.
        2. new_lines_unaffected (list of str): token lines for sentences not affected by the transformation.
        3. sents_unaffected (list of str): raw text of the unaffected sentences.
Note:
- The transformation functions (`perturbation_function`, `affect_function`, `filter_function`)
are expected to be available in the global scope.
"""
new_lines_affected = []
new_lines_unaffected = []
sents_unaffected = []
# Apply transformation to each sentence on line
for sent in line["sent_annotations"]: # Yj: 这处不明白为什么用annotations不用text?
tokens = perturbation_function(sent)
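        # Skip sentences with at most one non-marker token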
if len([tok for tok in tokens if tok not in MARKER_TOKEN_IDS]) <= 1:
continue
token_line = " ".join([str(tok) for tok in tokens])
# Check if sent is affected
if affect_function(sent):
# Check if this affected sentence should be filtered or not
if filter_function(sent):
new_lines_affected.append(token_line + "\n")
else: # Unaffected sentences
new_lines_unaffected.append(token_line + "\n")
sents_unaffected.append(sent["sent_text"] + "\n")
return new_lines_affected, new_lines_unaffected, sents_unaffected
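# Expected input shape for process_line (a sketch inferred from the accesses
# above; the exact fields come from the upstream parsing step that produces
# the *_parsed.json files):
#   line = {"sent_annotations": [{"sent_text": "...", ...}, ...]}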
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='Perturb BabyLM dataset',
description='Perturb BabyLM dataset by altering POS-tagged data')
parser.add_argument('perturbation_type',
default='all',
const='all',
nargs='?',
choices=PERTURBATIONS.keys(),
help='Perturbation function used to transform BabyLM dataset')
parser.add_argument('babylm_dataset',
default='all',
const='all',
nargs='?',
choices=BABYLM_SPLITS,
help='BabyLM dataset choice')
# Get args
args = parser.parse_args()
# Load dataset (only json files containing tagged data)
babylm_dataset = args.babylm_dataset
json_ext = "_parsed.json"
# babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
print("babylm_data:", babylm_data)
    # Get the perturbation, affect, and filter functions, plus the Qwen tokenizer
perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
qwen_tokenizer = PERTURBATIONS[args.perturbation_type]['qwen_tokenizer']
if babylm_dataset == "test": # Yj: 为什么abylm_dataset是test? BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
# Iterate over files and do transform
for file in babylm_data:
print(file)
            with open(file) as f:
                data = json.load(f)
# Perturb data iteratively
results = []
for line in tqdm.tqdm(data):
results.append(process_line(line))
new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
*results)
new_lines_affected = flatten_list(new_lines_affected)
new_lines_unaffected = flatten_list(new_lines_unaffected)
unaffected_sents = flatten_list(unaffected_sents)
# Name new file
new_file_affected = os.path.basename(
file).replace(json_ext, "_affected.test")
new_file_unaffected = os.path.basename(
file).replace(json_ext, "_unaffected.test")
file_unaffected_sents = os.path.basename(
file).replace(json_ext, "_unaffected_sents.test")
# Create directory
data_write_directory = f"{BABYLM_DATA_PATH}/Qwen_perturbed_data/{MODEL_NAME}"
directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
if not os.path.exists(directory_affected):
os.makedirs(directory_affected)
directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
if not os.path.exists(directory_unaffected):
os.makedirs(directory_unaffected)
directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
if not os.path.exists(directory_unaffected_sents):
os.makedirs(directory_unaffected_sents)
# Write files
write_file(directory_affected,
new_file_affected, new_lines_affected)
write_file(directory_unaffected,
new_file_unaffected, new_lines_unaffected)
write_file(directory_unaffected_sents,
file_unaffected_sents, unaffected_sents)
else:
# Yj: BABYLM_SPLITS = ['100M', '10M', 'dev', 'test', 'unittest']
# Iterate over files and do transform
for file in babylm_data:
print(file)
            with open(file) as f:
                data = json.load(f)
# Perturb data iteratively
results = []
for line in tqdm.tqdm(data):
results.append(process_line(line))
new_lines_affected, new_lines_unaffected, _ = zip(
*results)
new_lines_affected = flatten_list(new_lines_affected)
new_lines_unaffected = flatten_list(new_lines_unaffected)
# Combine affected and unaffected sentences
new_lines = new_lines_unaffected + new_lines_affected
# Name new file
if babylm_dataset == "dev":
new_file = os.path.basename(file).replace(json_ext, ".dev")
elif babylm_dataset == 'unittest':
new_file = os.path.basename(file).replace(json_ext, ".test")
# Print strings for unittest
new_lines_decoded = [qwen_tokenizer.decode(
[int(tok) for tok in line.split()]) + "\n" for line in new_lines]
new_lines_with_strings = []
for tokens, line in list(zip(new_lines, new_lines_decoded)):
new_lines_with_strings.append(tokens)
new_lines_with_strings.append(line)
new_lines = new_lines_with_strings
else:
new_file = os.path.basename(file).replace(json_ext, ".train") # '10M 100M' is training set
# Create directory and write file
directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
print("directory:", directory)
if not os.path.exists(directory):
os.makedirs(directory)
        write_file(directory, new_file, new_lines)
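# Example invocations (a sketch; assumes the *_parsed.json files sit under
# babylm_data/babylm_<split>/ relative to the working directory):
#   python perturb.py reversal test        # write affected/unaffected test files
#   pytest perturb.py -k reversal          # run the reversal equivalence tests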