import math
import os.path
import re
import subprocess
import tempfile
import nltk


class FScore(object):
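    """Bracketing scores parsed from the output of the EVALB scorer."""
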
    def __init__(self, recall, precision, fscore, complete_match, tagging_accuracy=100):
        self.recall = recall
        self.precision = precision
        self.fscore = fscore
        self.complete_match = complete_match
        self.tagging_accuracy = tagging_accuracy

    def __str__(self):
        return (
            f"("
            f"Recall={self.recall:.2f}, "
            f"Precision={self.precision:.2f}, "
            f"FScore={self.fscore:.2f}, "
            f"CompleteMatch={self.complete_match:.2f}"
        ) + (
            f", TaggingAccuracy={self.tagging_accuracy:.2f})"
            if self.tagging_accuracy < 100
            else ")"
        )


def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None):
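    """Score predicted trees against gold trees with the external EVALB tool.

    Writes both tree lists to temporary files, invokes EVALB on them, and
    parses the resulting report into an FScore. On failure the temporary
    files are kept so the inputs and output can be inspected.
    """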
    assert os.path.exists(evalb_dir)
    evalb_program_path = os.path.join(evalb_dir, "evalb")
    evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl")
    assert os.path.exists(evalb_program_path) or os.path.exists(
        evalb_spmrl_program_path
    )

    # Use the standard evalb binary if it exists, otherwise fall back to the
    # SPMRL build; each comes with its own parameter file.
    if os.path.exists(evalb_program_path):
        evalb_param_path = os.path.join(evalb_dir, "nk.prm")
    else:
        evalb_program_path = evalb_spmrl_program_path
        evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")

    assert os.path.exists(evalb_program_path)
    assert os.path.exists(evalb_param_path)
    assert len(gold_trees) == len(predicted_trees)

    # Sanity check: every predicted tree must cover exactly the same words
    # as its gold counterpart.
    for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
        assert isinstance(gold_tree, nltk.Tree)
        assert isinstance(predicted_tree, nltk.Tree)
        gold_leaves = list(gold_tree.leaves())
        predicted_leaves = list(predicted_tree.leaves())
        assert len(gold_leaves) == len(predicted_leaves)
        assert all(
            gold_word == predicted_word
            for gold_word, predicted_word in zip(gold_leaves, predicted_leaves)
        )
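
    # The temporary directory is cleaned up only on success, so the EVALB
    # inputs and output remain available for debugging if parsing fails.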
    temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
    gold_path = os.path.join(temp_dir.name, "gold.txt")
    predicted_path = os.path.join(temp_dir.name, "predicted.txt")
    output_path = os.path.join(temp_dir.name, "output.txt")

    with open(gold_path, "w") as outfile:
        if ref_gold_path is None:
            for tree in gold_trees:
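                # margin=1e100 prevents pformat from wrapping, so each tree
                # is written on a single line, as EVALB expects.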
                outfile.write("{}\n".format(tree.pformat(margin=1e100)))
        else:
            # For the SPMRL dataset our data loader performs some modifications
            # (like stripping morphological features), so we compare to the
            # raw gold file to be certain that we haven't spoiled the evaluation
            # in some way.
            with open(ref_gold_path) as goldfile:
                outfile.write(goldfile.read())

    with open(predicted_path, "w") as outfile:
        for tree in predicted_trees:
            outfile.write("{}\n".format(tree.pformat(margin=1e100)))

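    # Run EVALB through the shell so its report can be redirected into
    # output_path.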
    command = "{} -p {} {} {} > {}".format(
        evalb_program_path,
        evalb_param_path,
        gold_path,
        predicted_path,
        output_path,
    )
    subprocess.run(command, shell=True)

    fscore = FScore(math.nan, math.nan, math.nan, math.nan)

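    # Scan the EVALB report for the summary metrics.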
    with open(output_path) as infile:
        for line in infile:
            match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.recall = float(match.group(1))
            match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.precision = float(match.group(1))
            match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.fscore = float(match.group(1))
            match = re.match(r"Complete match\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.complete_match = float(match.group(1))
            match = re.match(r"Tagging accuracy\s+=\s+(\d+\.\d+)", line)
            if match:
                fscore.tagging_accuracy = float(match.group(1))
                # Tagging accuracy is the last metric we need, so stop here.
                break

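    # EVALB can report FMeasure as "nan" when both precision and recall are
    # zero, so a zero recall or precision still counts as a successful run.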
    success = (
        not math.isnan(fscore.fscore)
        or fscore.recall == 0.0
        or fscore.precision == 0.0
    )
    if success:
        temp_dir.cleanup()
    else:
        print("Error reading EVALB results.")
        print("Gold path: {}".format(gold_path))
        print("Predicted path: {}".format(predicted_path))
        print("Output path: {}".format(output_path))

    return fscore
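

# Example usage (a sketch with hypothetical paths; assumes the EVALB
# binaries have been compiled under ./EVALB):
#
#     gold = [nltk.Tree.fromstring("(S (NP (D the) (N dog)) (VP (V barks)))")]
#     pred = [nltk.Tree.fromstring("(S (NP (D the) (N dog)) (VP (V barks)))")]
#     print(evalb("./EVALB", gold, pred))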