# Source: MMOCR / mmocr/core/evaluation/ner_metric.py
# Uploaded by tomofi ("Add application file", commit 2366e36, 3.79 kB).
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
def gt_label2entity(gt_infos):
    """Flatten ground-truth label dicts into per-sample entity lists.

    Args:
        gt_infos (list[dict]): Ground-truth records. Each record's
            ``'label'`` entry maps a category to a dict whose values are
            sequences of positions; the first two items of each position
            are used as the start and end.

    Returns:
        list[list]: One list per record, each containing
        ``[category, start_position, end_position]`` entries in
        iteration order.
    """

    def _entities_of(label):
        # Collect every (category, start, end) triple of a single record.
        entities = []
        for category, mentions in label.items():
            for spans in mentions.values():
                entities.extend([category, span[0], span[1]] for span in spans)
        return entities

    return [_entities_of(info['label']) for info in gt_infos]
def _compute_f1(origin, found, right):
"""Calculate recall, precision, f1-score.
Args:
origin (int): Original entities in groundtruth.
found (int): Predicted entities from model.
right (int): Predicted entities that
can match to the original annotation.
Returns:
recall (float): Metric of recall.
precision (float): Metric of precision.
f1 (float): Metric of f1-score.
"""
recall = 0 if origin == 0 else (right / origin)
precision = 0 if found == 0 else (right / found)
f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
precision + recall)
return recall, precision, f1
def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and F1-score for all categories.

    Args:
        pred_entities (list[list]): Predicted entities per sample; each
            entity's first item is its category, e.g.
            ``[category, start_position, end_position]``.
        gt_entities (list[list]): Ground-truth entities per sample in the
            same format, indexed in parallel with ``pred_entities``.

    Returns:
        dict: Maps each category present in the ground truth, plus
        ``'all'``, to a dict with ``'precision'``, ``'recall'`` and
        ``'f1-score'``. Categories that appear only in predictions get no
        per-class entry but still count towards ``'all'``.
    """
    origins = []
    founds = []
    rights = []
    for i, line_preds in enumerate(pred_entities):
        line_gts = gt_entities[i]
        origins.extend(line_gts)
        founds.extend(line_preds)
        # Match as a multiset: each ground-truth entity can be claimed by at
        # most one prediction. A plain ``in`` test would count a duplicated
        # prediction twice, letting ``right`` exceed ``origin`` (recall > 1).
        unmatched = list(line_gts)
        for entity in line_preds:
            if entity in unmatched:
                unmatched.remove(entity)
                rights.append(entity)

    class_info = {}
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    # Per-class scores are reported only for ground-truth categories.
    for type_, origin in origin_counter.items():
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }

    # Micro-averaged scores over every entity regardless of category.
    recall, precision, f1 = _compute_f1(len(origins), len(founds), len(rights))
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info
def eval_ner_f1(results, gt_infos):
    """Score NER predictions against ground-truth annotations.

    Args:
        results (list): Predicted entities, one iterable per sample, aligned
            index-by-index with ``gt_infos``.
        gt_infos (list[dict]): Ground-truth records holding text and label.

    Returns:
        dict: Precision, recall and F1-score for every ground-truth category
        and for ``'all'`` categories combined.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    # Copy each sample's predictions into a fresh list; indexing is driven by
    # gt_infos so the sample alignment matches the ground truth.
    pred_entities = [list(results[idx]) for idx in range(len(gt_infos))]
    assert len(pred_entities) == len(gt_entities)
    return compute_f1_all(pred_entities, gt_entities)