# tag.py
# Author: Julie Kallini
# For importing utils
import sys
sys.path.append("..")
import pytest
import glob
import tqdm
import os
import argparse
import stanza
import json
from transformers import AutoTokenizer
# Define the function to chunk text
def chunk_text(text, tokenizer, max_length=512):
    tokens = tokenizer(text)['input_ids']
    chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
    return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]
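# Illustrative usage of chunk_text (assumes `tokenizer` is a Hugging Face tokenizer,
# as loaded in the main block below):
#   chunk_text(long_text, tokenizer, max_length=512)
#   -> list of decoded text chunks, each covering at most 512 subword tokens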
# Test case for checking equivalence of original and parsed files
test_all_files = sorted(glob.glob("babylm_data/babylm_*/*"))
test_original_files = [f for f in test_all_files if ".json" not in f]
test_json_files = [f for f in test_all_files if "_parsed.json" in f]
test_cases = list(zip(test_original_files, test_json_files))
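# NOTE: zip() assumes the sorted original files and their *_parsed.json outputs
# line up one-to-one; originals without a parsed JSON will skew the pairing.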
@pytest.mark.parametrize("original_file, json_file", test_cases)
def test_equivalent_lines(original_file, json_file):
    # Read the original file and remove all whitespace
    with open(original_file) as f:
        original_data = "".join(f.readlines())
    original_data = "".join(original_data.split())

    # Concatenate the sentence texts stored in the parsed JSON
    with open(json_file) as f:
        json_lines = json.load(f)
    json_data = ""
    for line in json_lines:
        for sent in line["sent_annotations"]:
            json_data += sent["sent_text"]
    json_data = "".join(json_data.split())

    # Test equivalence
    assert (original_data == json_data)
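# The equivalence check above is collected by pytest; once the parsed JSON files
# exist, it can be run with, e.g.:
#   pytest tag.py -k test_equivalent_lines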
# Constituency parsing function
def __get_constituency_parse(sent, nlp):
    # Try parsing the doc; return None if the parser fails on this sentence
    try:
        parse_doc = nlp(sent.text)
    except Exception:
        return None

    # Get set of constituency parse trees
    parse_trees = [str(s.constituency) for s in parse_doc.sentences]

    # Join parse trees and add ROOT
    constituency_parse = "(ROOT " + " ".join(parse_trees) + ")"
    return constituency_parse
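# Note: the per-sentence trees returned by Stanza are joined under one outer
# "(ROOT ...)" wrapper, so multi-sentence input still yields a single bracketed
# string; a return value of None marks a failed parse.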
# Main function
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Tag BabyLM dataset',
        description='Tag BabyLM dataset using Stanza')
    parser.add_argument('path', type=argparse.FileType('r'),
                        nargs='+', help="Path to file(s)")
    parser.add_argument('-p', '--parse', action='store_true',
                        help="Include constituency parse")

    # Get args
    args = parser.parse_args()

    # Init Stanza NLP tools (tokenization, POS tagging, lemmatization)
    nlp1 = stanza.Pipeline(
        lang='en',
        processors='tokenize,pos,lemma',
        package="default_accurate",
        use_gpu=True)

    # If a constituency parse is needed, init a second Stanza pipeline
    if args.parse:
        nlp2 = stanza.Pipeline(lang='en',
                               processors='tokenize,pos,constituency',
                               package="default_accurate",
                               use_gpu=True)

    BATCH_SIZE = 100

    # Tokenizer for splitting long text
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
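    # The BERT tokenizer is only used by chunk_text() to bound chunk length
    # (512 subword tokens); Stanza does its own tokenization on each chunk.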
    # Iterate over BabyLM files
    for file in args.path:

        print(file.name)
        lines = file.readlines()

        # Strip lines and join text into batches of BATCH_SIZE lines
        print("Concatenating lines...")
        lines = [l.strip() for l in lines]
        line_batches = [lines[i:i + BATCH_SIZE]
                        for i in range(0, len(lines), BATCH_SIZE)]
        text_batches = [" ".join(l) for l in line_batches]

        # Iterate over text batches in the file and track annotations
        line_annotations = []
        print("Segmenting and parsing text batches...")
        for text in tqdm.tqdm(text_batches):

            # Split the text into chunks if it exceeds the max length
            text_chunks = chunk_text(text, tokenizer)

            # Iterate over each chunk
            for chunk in text_chunks:

                # Tokenize and tag the chunk with Stanza
                doc = nlp1(chunk)
                # Iterate over sentences in the chunk and track annotations
                sent_annotations = []
                for sent in doc.sentences:

                    # Iterate over words in the sentence and track annotations
                    word_annotations = []
                    for token, word in zip(sent.tokens, sent.words):
                        wa = {
                            'id': word.id,
                            'text': word.text,
                            'lemma': word.lemma,
                            'upos': word.upos,
                            'xpos': word.xpos,
                            'feats': word.feats,
                            'start_char': token.start_char,
                            'end_char': token.end_char
                        }
                        word_annotations.append(wa)  # Track word annotation

                    # Get constituency parse if needed
                    if args.parse:
                        constituency_parse = __get_constituency_parse(sent, nlp2)
                        sa = {
                            'sent_text': sent.text,
                            'constituency_parse': constituency_parse,
                            'word_annotations': word_annotations,
                        }
                    else:
                        sa = {
                            'sent_text': sent.text,
                            'word_annotations': word_annotations,
                        }
                    sent_annotations.append(sa)  # Track sent annotation

                la = {
                    'sent_annotations': sent_annotations
                }
                line_annotations.append(la)  # Track line annotation
        # Write annotations to file as a JSON
        print("Writing JSON outfile...")
        ext = '_parsed.json' if args.parse else '.json'
        json_filename = os.path.splitext(file.name)[0] + ext
        with open(json_filename, "w") as outfile:
            json.dump(line_annotations, outfile, indent=4)
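# Example invocation (paths are illustrative; point it at the local BabyLM files):
#   python tag.py babylm_data/babylm_10M/*.train            # POS/lemma tags only
#   python tag.py babylm_data/babylm_10M/*.train --parse    # also add constituency parses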