File size: 14,605 Bytes
7332c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# perturb.py
# Author: Julie Kallini

# For importing utils
import sys
sys.path.append("..")

from utils_qwen import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from glob import glob
import numpy as np
import itertools
import json
import os
import tqdm
import argparse
import pytest

MODEL_NAME = "Qwen2.5-7B"

def lines_equivalent_3pres(file1_path, file2_path):
    """Return True iff the two token files are line-by-line equivalent once
    the singular/plural marker tokens are filtered out.

    Uses itertools.zip_longest so a file with extra trailing lines is
    reported as non-equivalent: plain zip() consumed (and silently dropped)
    one extra line from the longer file before the explicit length check,
    so a file exactly one line longer compared as equivalent.
    """
    markers = (marker_sg_token, marker_pl_token)
    with open(file1_path, 'r') as file1, open(file2_path, 'r') as file2:
        for line1, line2 in itertools.zip_longest(file1, file2):
            # One file ended before the other -> different line counts
            if line1 is None or line2 is None:
                return False
            # Split each line and compare token lists with markers removed
            res1 = [i for i in line1.split() if int(i) not in markers]
            res2 = [i for i in line2.split() if int(i) not in markers]
            if res1 != res2:
                print(line1)
                print(line2)
                return False

    return True


perturbation_pairs_3pres = [
    ("0tokens", "4tokens"),
    ("0tokens", "4words"),
    ("4tokens", "4words"),
]

# Combination tests over perturbation pairs related to third-person
# singular/plural ("3pres") marking.

test_data = itertools.product(  
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_3pres)  # every (split, genre, pair) combination exercised below

# Consumed by test_3pres_all_equivalent to cover all perturbation-strategy
# combinations; the affected/unaffected test subsets let us compare the data
# before and after the perturbation takes effect.


@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)  # runs once per (split, genre, pair) combination
def test_3pres_all_equivalent(split, genre, perturbation_pair):
    """Check that paired 3pres perturbations produced equivalent token files
    for every (split, genre) combination. `genre` values are the corpus names
    declared in GENRES (see utils)."""

    perturbation1, perturbation2 = perturbation_pair

    # Map the split name to the file naming scheme used by the writer below
    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"  # dev == validation split
    else:
        raise ValueError(f"Unexpected split: {split}")

    # BUG FIX: the computed filename was never interpolated into the paths
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_3pres_{perturbation2}/babylm_{split}/{filename}"

    # Compare the two files line by line (ignoring sg/pl marker tokens)
    assert lines_equivalent_3pres(path1, path2), f"File {filename} of " + \
        f"3pres_{perturbation1} and 3pres_{perturbation2} have non-equivalent lines!"


def lines_equivalent_reversal(rev_path, ident_path):
    """Return True iff each line of the reversal file equals the matching
    identity line up to (and including) the REV marker, and the remainder,
    reversed, equals the identity remainder.

    Uses itertools.zip_longest so files with different line counts are
    reported as non-equivalent (plain zip() silently dropped one extra line
    from the longer file before the explicit length check).

    Raises ValueError if a reversal line lacks the REV marker token.
    """
    with open(rev_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in itertools.zip_longest(file1, file2):
            # One file ended before the other -> different line counts
            if line1 is None or line2 is None:
                return False

            line1_tokens = line1.split()
            line2_tokens = line2.split()

            # Locate the REV marker in the reversal line
            marker_index = line1_tokens.index(str(marker_rev_token))

            # Prefix up to and including the marker must match exactly
            if line1_tokens[:marker_index + 1] != line2_tokens[:marker_index + 1]:
                return False

            # Reversing the suffix must recover the identity suffix
            if line1_tokens[marker_index + 1:][::-1] != line2_tokens[marker_index + 1:]:
                return False

    return True
        

perturbation_pairs_reversal = [
    ("reversal", "reversal_identity"),
]
# Combination tests for the reversal perturbation pair.

test_data = itertools.product(
    ["100M", "dev", "test_affected"], GENRES.keys(), perturbation_pairs_reversal)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_reversal_all_equivalent(split, genre, perturbation_pair):
    """Check that the reversal perturbation file matches its identity
    counterpart for every (split, genre) combination."""

    perturbation1, perturbation2 = perturbation_pair

    # Map the split name to the file naming scheme used by the writer below
    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        raise ValueError(f"Unexpected split: {split}")

    # BUG FIX: the computed filename was never interpolated into the paths
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_reversal(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def lines_equivalent_determiner_swap(det_path, ident_path):
    """Return True iff each line of the determiner-swap file contains the
    same token SET as the matching identity line (order and multiplicity are
    deliberately ignored, since the swap reorders tokens).

    Uses itertools.zip_longest so files with different line counts are
    reported as non-equivalent (plain zip() silently dropped one extra line
    from the longer file before the explicit length check).
    """
    with open(det_path, 'r') as file1, open(ident_path, 'r') as file2:
        for line1, line2 in itertools.zip_longest(file1, file2):
            # One file ended before the other -> different line counts
            if line1 is None or line2 is None:
                return False
            # Compare as sets: the perturbation only permutes tokens
            if set(line1.split()) != set(line2.split()):
                print(line1.split())
                print(line2.split())
                return False

    return True
        

# Combination tests for the determiner-swap perturbation pair. Renamed: the
# original rebound `perturbation_pairs_reversal` with unrelated contents,
# which was misleading when reading the file top to bottom.
perturbation_pairs_determiner_swap = [
    ("determiner_swap", "determiner_swap_identity"),
]
test_data = itertools.product(
    ["100M", "dev", "test_affected", "test_unaffected"], GENRES.keys(), perturbation_pairs_determiner_swap)

@pytest.mark.parametrize("split, genre, perturbation_pair", test_data)
def test_determiner_swap_all_equivalent(split, genre, perturbation_pair):
    """Check that the determiner-swap perturbation file matches its identity
    counterpart for every (split, genre) combination."""

    perturbation1, perturbation2 = perturbation_pair

    # Map the split name to the file naming scheme used by the writer below
    if split in ("100M", "10M"):
        filename = f"{genre}.train"
    elif split == "test_affected":
        filename = f"{genre}_affected.test"
    elif split == "test_unaffected":
        filename = f"{genre}_unaffected.test"
    elif split == "dev":
        filename = f"{genre}.dev"
    else:
        raise ValueError(f"Unexpected split: {split}")

    # BUG FIX: the computed filename was never interpolated into the paths
    path1 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation1}/babylm_{split}/{filename}"
    path2 = f"{BABYLM_DATA_PATH}/babylm_data_perturbed_qwen/babylm_{perturbation2}/babylm_{split}/{filename}"

    assert lines_equivalent_determiner_swap(path1, path2), f"File {filename} of " + \
        f"{perturbation1} and {perturbation2} have non-equivalent lines!"


def flatten_list(l):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [item for sublist in l for item in sublist]


def process_line(line):
    """
    Apply the globally-configured perturbation to every sentence of a
    dataset line and sort the results into affected / unaffected buckets.

    Parameters:
    - line (dict): one parsed dataset entry; its "sent_annotations" list
      holds per-sentence annotation dicts.

    Returns:
    - tuple: three lists of newline-terminated strings:
        1. token lines for sentences affected by the transformation
           (and accepted by the filter),
        2. token lines for unaffected sentences,
        3. raw text of the unaffected sentences.

    Note:
    - `perturbation_function`, `affect_function`, `filter_function` and
      MARKER_TOKEN_IDS are expected to be available in the global scope
      (they are bound in the __main__ section of this module).
    """

    affected_lines = []
    unaffected_lines = []
    unaffected_text = []

    # Transform each annotated sentence on the line
    for sent in line["sent_annotations"]:

        tokens = perturbation_function(sent)

        # Skip sentences that are empty or a single token once the special
        # marker tokens are discounted
        if len([tok for tok in tokens if tok not in MARKER_TOKEN_IDS]) <= 1:
            continue

        token_line = " ".join(str(tok) for tok in tokens)

        if not affect_function(sent):
            # Unaffected sentence: keep both the token ids and the raw text
            unaffected_lines.append(token_line + "\n")
            unaffected_text.append(sent["sent_text"] + "\n")
        elif filter_function(sent):
            # Affected sentence that also passes the filter
            affected_lines.append(token_line + "\n")

    return affected_lines, unaffected_lines, unaffected_text


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        prog='Perturb BabyLM dataset',
        description='Perturb BabyLM dataset by altering POS-tagged data')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('babylm_dataset',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=BABYLM_SPLITS,
                        help='BabyLM dataset choice')

    # Get args
    args = parser.parse_args()

    # Load dataset (only json files containing tagged data)
    babylm_dataset = args.babylm_dataset
    json_ext = "_parsed.json"
    # babylm_data = glob(f"{BABYLM_DATA_PATH}/babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    babylm_data = glob(f"babylm_data/babylm_{babylm_dataset}/*{json_ext}")
    print("babylm_data:", babylm_data)

    # Get perturbation, affect, and filter functions. These stay at module
    # scope on purpose: process_line() reads them as globals.
    perturbation_function = PERTURBATIONS[args.perturbation_type]['perturbation_function']
    affect_function = PERTURBATIONS[args.perturbation_type]['affect_function']
    filter_function = PERTURBATIONS[args.perturbation_type]['filter_function']
    qwen_tokenizer = PERTURBATIONS[args.perturbation_type]['qwen_tokenizer']

    if babylm_dataset == "test":
        # The test split is written as three separate files per input:
        # affected token lines, unaffected token lines, and the raw text of
        # the unaffected sentences.

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            # Context manager closes the handle even if json.load fails
            with open(file) as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            # Guard against an empty input file: zip(*[]) would fail to
            # unpack into three names
            if results:
                new_lines_affected, new_lines_unaffected, unaffected_sents = zip(
                    *results)
            else:
                new_lines_affected = new_lines_unaffected = unaffected_sents = ()
            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)
            unaffected_sents = flatten_list(unaffected_sents)

            # Name new files
            new_file_affected = os.path.basename(
                file).replace(json_ext, "_affected.test")
            new_file_unaffected = os.path.basename(
                file).replace(json_ext, "_unaffected.test")
            file_unaffected_sents = os.path.basename(
                file).replace(json_ext, "_unaffected_sents.test")

            # Create output directories (exist_ok avoids the check-then-create race)
            data_write_directory = f"{BABYLM_DATA_PATH}/Qwen_perturbed_data/{MODEL_NAME}"
            directory_affected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_affected/"
            os.makedirs(directory_affected, exist_ok=True)
            directory_unaffected = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected/"
            os.makedirs(directory_unaffected, exist_ok=True)
            directory_unaffected_sents = f"{data_write_directory}/babylm_{args.perturbation_type}/babylm_test_unaffected_sents/"
            os.makedirs(directory_unaffected_sents, exist_ok=True)

            # Write files
            write_file(directory_affected,
                       new_file_affected, new_lines_affected)
            write_file(directory_unaffected,
                       new_file_unaffected, new_lines_unaffected)
            write_file(directory_unaffected_sents,
                       file_unaffected_sents, unaffected_sents)

    else:
        # Remaining splits (BABYLM_SPLITS = ['100M', '10M', 'dev', 'test',
        # 'unittest']) are written as one combined file per input.

        # Iterate over files and do transform
        for file in babylm_data:
            print(file)
            with open(file) as f:
                data = json.load(f)

            # Perturb data iteratively
            results = [process_line(line) for line in tqdm.tqdm(data)]

            # Guard against an empty input file (see note above)
            if results:
                new_lines_affected, new_lines_unaffected, _ = zip(*results)
            else:
                new_lines_affected = new_lines_unaffected = ()

            new_lines_affected = flatten_list(new_lines_affected)
            new_lines_unaffected = flatten_list(new_lines_unaffected)

            # Combine affected and unaffected sentences
            new_lines = new_lines_unaffected + new_lines_affected

            # Name new file
            if babylm_dataset == "dev":
                new_file = os.path.basename(file).replace(json_ext, ".dev")
            elif babylm_dataset == 'unittest':
                new_file = os.path.basename(file).replace(json_ext, ".test")

                # For the unittest split, interleave each token line with its
                # decoded string so the output is human-readable
                new_lines_decoded = [qwen_tokenizer.decode(
                    [int(tok) for tok in line.split()]) + "\n" for line in new_lines]
                new_lines_with_strings = []
                for tokens, decoded in zip(new_lines, new_lines_decoded):
                    new_lines_with_strings.append(tokens)
                    new_lines_with_strings.append(decoded)
                new_lines = new_lines_with_strings

            else:
                # '10M' / '100M' are the training splits
                new_file = os.path.basename(file).replace(json_ext, ".train")

            # Create directory and write file
            directory = f"{BABYLM_DATA_PATH}/Perturbed_data/{MODEL_NAME}/babylm_{args.perturbation_type}/babylm_{babylm_dataset}/"
            print("directory:", directory)
            os.makedirs(directory, exist_ok=True)
            write_file(directory, new_file, new_lines)