|
|
|
|
|
|
|
|
|
import sys |
|
sys.path.append("..") |
|
|
|
import pytest |
|
import glob |
|
import tqdm |
|
import os |
|
import argparse |
|
import stanza |
|
import json |
|
from transformers import AutoTokenizer |
|
|
|
|
|
def chunk_text(text, tokenizer, max_length=512):
    """Split ``text`` into pieces of at most ``max_length`` tokenizer tokens.

    The text is tokenized once, the resulting ids are cut into consecutive
    windows of ``max_length`` ids, and every window is decoded back into a
    string with special tokens skipped.

    Args:
        text: The string to split.
        tokenizer: A callable that returns a mapping with an ``'input_ids'``
            sequence for ``text`` and that exposes
            ``decode(ids, skip_special_tokens=True)``.
        max_length: Maximum number of token ids per chunk; must be >= 1.

    Returns:
        The decoded, non-empty chunks in their original order.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    # A negative step would make ``range`` empty and silently drop the whole
    # text; zero would fail with an unrelated-looking ``range`` error.
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    # NOTE(review): decoding token ids is not guaranteed to reproduce the
    # input text exactly; that depends on the tokenizer - confirm this is
    # acceptable for the caller.
    token_ids = tokenizer(text)['input_ids']
    decoded_chunks = (
        tokenizer.decode(token_ids[start:start + max_length],
                         skip_special_tokens=True)
        for start in range(0, len(token_ids), max_length)
    )

    # A window that holds only special tokens decodes to an empty string;
    # skip it so callers never receive a chunk without text.
    return [chunk for chunk in decoded_chunks if chunk]
|
|
|
|
|
def _pair_original_with_parsed(paths):
    """Pair each original data file with its ``_parsed.json`` counterpart.

    The expected JSON name is derived from the original path with the same
    rule the tagging script uses when writing parsed output: the extension
    is removed and ``_parsed.json`` is appended. Pairing by name (instead of
    by position in two separately sorted lists) keeps the pairs correct when
    a JSON file is missing or when the two lists sort differently.

    Args:
        paths: Iterable of file paths (originals and JSON files mixed).

    Returns:
        A list of ``(original_path, parsed_json_path)`` tuples, in the order
        the originals appear in ``paths``. Originals without a parsed JSON
        file in ``paths`` are left out.
    """
    paths = list(paths)
    available = set(paths)
    pairs = []
    for path in paths:
        if ".json" in path:
            continue
        parsed_path = os.path.splitext(path)[0] + "_parsed.json"
        if parsed_path in available:
            pairs.append((path, parsed_path))
    return pairs


# Discover the data files once at import time so pytest can parametrize over
# them; the glob is relative to the current working directory.
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = _pair_original_with_parsed(test_all_files)
|
|
|
@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    """Check that the parsed JSON holds the same text as the original file.

    Both sides are compared with all whitespace removed, so differences in
    line breaks or spacing do not fail the test.

    Args:
        original_file: Path to the original text file.
        json_file: Path to the JSON file holding a list of entries, each
            with a ``"sent_annotations"`` list whose items carry a
            ``"sent_text"`` string.
    """
    # Context managers make sure both handles are closed even when the
    # assertion (or JSON decoding) fails.
    with open(original_file) as original_handle:
        original_data = "".join(original_handle.read().split())

    with open(json_file) as json_handle:
        json_lines = json.load(json_handle)

    # Concatenate every sentence text in file order, then drop whitespace.
    json_data = "".join(
        sent["sent_text"]
        for line in json_lines
        for sent in line["sent_annotations"]
    )
    json_data = "".join(json_data.split())

    assert original_data == json_data
|
|
|
|
|
def __get_constituency_parse(sent, nlp):
    """Return a single bracketed constituency parse for ``sent``.

    ``nlp`` is run on ``sent.text``; the ``constituency`` value of every
    sentence in the resulting document is converted to a string and all of
    them are wrapped together under one ``ROOT`` node.

    Args:
        sent: Object with a ``text`` attribute holding the sentence string.
        nlp: Callable that takes the text and returns an object with a
            ``sentences`` iterable whose items have a ``constituency``
            attribute.

    Returns:
        The string ``"(ROOT <tree> <tree> ...)"``, or ``None`` when ``nlp``
        raises an exception for this text.
    """
    # Parsing is best-effort: a failure on one sentence yields ``None``
    # instead of aborting the run. Only ``Exception`` is caught so that
    # KeyboardInterrupt / SystemExit still stop the program.
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # The re-run may segment the text into several sentences; keep them all.
    parse_trees = [str(parsed.constituency) for parsed in parse_doc.sentences]

    return "(ROOT " + " ".join(parse_trees) + ")"
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: annotate each input file with Stanza and write the
    # result next to it as JSON ("<stem>_parsed.json" when --parse is given,
    # "<stem>.json" otherwise).

    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    args = parser.parse_args()

    # Pipeline used for sentence segmentation and word-level annotations.
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize, pos, lemma',
        package="default_accurate",
        use_gpu=True)

    # The constituency pipeline is only built when parses were requested.
    nlp2 = None
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    # Number of input lines joined into one text before annotation.
    BATCH_SIZE = 100

    # Tokenizer used only to split each batch into bounded-size chunks.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    for file in args.path:

        print(file.name)
        # argparse opened the handle already; close it as soon as the
        # content is read instead of keeping every input open until exit.
        with file:
            lines = file.readlines()

        # Strip the lines and join them into batches of BATCH_SIZE lines.
        print("Concatenating lines...")
        lines = [line.strip() for line in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(batch) for batch in line_batches]

        # One entry is collected per annotated chunk of text.
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):

            text_chunks = chunk_text(text, tokenizer)

            for chunk in text_chunks:

                doc = nlp1(chunk)

                sent_annotations = []
                for sent in doc.sentences:

                    # Pair each token with its word to record both the
                    # word-level tags and the token's character offsets.
                    word_annotations = [
                        {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char,
                        }
                        for token, word in zip(sent.tokens, sent.words)
                    ]

                    # Keys are inserted in a fixed order so the JSON layout
                    # is sent_text, [constituency_parse], word_annotations.
                    sa = {'sent_text': sent.text}
                    if args.parse:
                        sa['constituency_parse'] = \
                            __get_constituency_parse(sent, nlp2)
                    sa['word_annotations'] = word_annotations
                    sent_annotations.append(sa)

                la = {
                    'sent_annotations': sent_annotations
                }
                line_annotations.append(la)

        # Write this file's annotations next to the input as JSON.
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
|
|