chendl's picture
add requirements
a1d409e
raw
history blame
5.51 kB
import logging
import os
from typing import List, TextIO, Union
from conllu import parse_incr
from utils_ner import InputExample, Split, TokenClassificationTask
# Module-level logger named after this module; the task classes below use it for warnings.
logger = logging.getLogger(__name__)
class NER(TokenClassificationTask):
    """Token-classification task for column-formatted NER files.

    Each non-boundary line holds one token as space-separated columns with
    the token text in the first column. Blank lines and lines starting with
    ``-DOCSTART-`` separate sentences. The label is read from the column at
    index ``label_idx``.
    """

    def __init__(self, label_idx=-1):
        # in NER datasets, the last column is usually reserved for NER label
        self.label_idx = label_idx

    @staticmethod
    def _is_sentence_boundary(line: str) -> bool:
        """Return True for a blank line or a ``-DOCSTART-`` line."""
        return line.startswith("-DOCSTART-") or line in ("", "\n")

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        """Read ``<data_dir>/<mode>.txt`` and return one example per sentence.

        Args:
            data_dir: Directory containing the ``<mode>.txt`` file.
            mode: A ``Split`` member (its ``value`` is used) or the split
                name as a string.

        Returns:
            A list of ``InputExample`` objects with guids ``"<mode>-<n>"``,
            numbered from 1 in file order. Sentences without tokens are
            skipped.
        """
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            words = []
            labels = []
            for line in f:
                if self._is_sentence_boundary(line):
                    if words:
                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    # Strip the line terminator before splitting; otherwise a
                    # single-column line would yield a word ending in "\n".
                    splits = line.rstrip("\n").split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[self.label_idx])
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
            # Flush the last sentence when the file does not end with a boundary line.
            if words:
                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        """Write ``"<token> <prediction>"`` lines mirroring the test input.

        Boundary lines are copied through unchanged. Tokens for which no
        prediction is available are not written; a warning is logged instead.

        Args:
            writer: Destination text stream.
            test_input_reader: The test file, in the same format that
                ``read_examples_from_file`` reads.
            preds_list: One list of predicted labels per example, in file
                order. It is read by index and left unmodified.
        """
        example_id = 0
        position = 0  # index of the next unused prediction in the current example
        for line in test_input_reader:
            # None once every prediction list is used up, e.g. for extra blank
            # lines at the end of the file; indexing there would raise.
            current = preds_list[example_id] if example_id < len(preds_list) else None
            if self._is_sentence_boundary(line):
                writer.write(line)
                # Move on only when the current example is exhausted, so that
                # consecutive boundary lines do not skip an example.
                if current is not None and position >= len(current):
                    example_id += 1
                    position = 0
            elif current is not None and position < len(current):
                writer.write(f"{line.split()[0]} {current[position]}\n")
                position += 1
            else:
                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    def get_labels(self, path: str) -> List[str]:
        """Return the label list, read from ``path`` or a built-in default.

        Args:
            path: File with one label per line. If falsy, the built-in
                default list is returned.

        Returns:
            The labels in file order, with ``"O"`` prepended when the file
            does not contain it.
        """
        if path:
            # Read as UTF-8, the same encoding used for the data files.
            with open(path, "r", encoding="utf-8") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
class Chunk(NER):
    """Chunking task: the ``NER`` reader/writer with the label taken from column -2."""

    def __init__(self):
        # in CONLL2003 dataset chunk column is second-to-last
        super().__init__(label_idx=-2)

    def get_labels(self, path: str) -> List[str]:
        """Return the chunk label list, read from ``path`` or a built-in default.

        Args:
            path: File with one label per line. If falsy, the built-in
                default list is returned.

        Returns:
            The labels in file order, with ``"O"`` prepended when the file
            does not contain it.
        """
        if path:
            # Read as UTF-8, the same encoding the inherited reader uses for data files.
            with open(path, "r", encoding="utf-8") as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return [
                "O",
                "B-ADVP",
                "B-INTJ",
                "B-LST",
                "B-PRT",
                "B-NP",
                "B-SBAR",
                "B-VP",
                "B-ADJP",
                "B-CONJP",
                "B-PP",
                "I-ADVP",
                "I-INTJ",
                "I-LST",
                "I-PRT",
                "I-NP",
                "I-SBAR",
                "I-VP",
                "I-ADJP",
                "I-CONJP",
                "I-PP",
            ]
class POS(TokenClassificationTask):
    """Part-of-speech tagging task over files parsed with ``conllu.parse_incr``.

    The token text is read from each token's ``"form"`` field and the label
    from its ``"upos"`` field.
    """

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
        """Read ``<data_dir>/<mode>.txt`` and return one example per sentence.

        Args:
            data_dir: Directory containing the ``<mode>.txt`` file.
            mode: A ``Split`` member (its ``value`` is used) or the split
                name as a string.

        Returns:
            A list of ``InputExample`` objects with guids ``"<mode>-<n>"``,
            numbered from 1 in file order. Sentences without tokens are
            skipped.
        """
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            for sentence in parse_incr(f):
                words = []
                labels = []
                for token in sentence:
                    words.append(token["form"])
                    labels.append(token["upos"])
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                    guid_index += 1
        return examples

    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
        """Write one line per sentence as ``"form (gold|predicted) "`` items.

        Tokens for which no prediction is available are left out of the line
        and a warning is logged for each.

        Args:
            writer: Destination text stream.
            test_input_reader: The test file, in the same format that
                ``read_examples_from_file`` reads.
            preds_list: One list of predicted labels per example, in file
                order. It is read by index and left unmodified.
        """
        example_id = 0
        for sentence in parse_incr(test_input_reader):
            tokens = list(sentence)
            if not tokens:
                # read_examples_from_file skips token-less sentences, so they
                # have no entry in preds_list; emit the empty line only.
                writer.write("\n")
                continue
            predictions = preds_list[example_id]
            pieces = []
            for position, token in enumerate(tokens):
                if position < len(predictions):
                    pieces.append(f'{token["form"]} ({token["upos"]}|{predictions[position]}) ')
                else:
                    # Fewer predictions than tokens: warn instead of raising,
                    # as the NER writer does.
                    logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", token["form"])
            writer.write("".join(pieces) + "\n")
            example_id += 1

    def get_labels(self, path: str) -> List[str]:
        """Return the tag list, read from ``path`` or a built-in default.

        Args:
            path: File with one tag per line. If falsy, the built-in default
                list is returned.

        Returns:
            The tags in file order, exactly as they appear in the file.
        """
        if path:
            # Read as UTF-8, the same encoding used for the data files.
            with open(path, "r", encoding="utf-8") as f:
                return f.read().splitlines()
        else:
            return [
                "ADJ",
                "ADP",
                "ADV",
                "AUX",
                "CCONJ",
                "DET",
                "INTJ",
                "NOUN",
                "NUM",
                "PART",
                "PRON",
                "PROPN",
                "PUNCT",
                "SCONJ",
                "SYM",
                "VERB",
                "X",
            ]