# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter


def gt_label2entity(gt_infos):
    """Get all entities from ground truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information containing text and
            label.

    Returns:
        gt_entities (list[list]): Original labeled entities in ground truth.
            [[category, start_position, end_position]]
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        label = gt_info['label']
        for key, value in label.items():
            for _, places in value.items():
                for place in places:
                    line_entities.append([key, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities
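
# Illustrative usage (the nested label format below is inferred from the
# loop structure above: category -> entity text -> list of [start, end]
# character spans; the sample values are hypothetical):
#
#   gt_infos = [{'text': 'Beijing is big',
#                'label': {'LOC': {'Beijing': [[0, 6]]}}}]
#   gt_label2entity(gt_infos)  # -> [[['LOC', 0, 6]]]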


def _compute_f1(origin, found, right):
    """Calculate recall, precision and f1-score.

    Args:
        origin (int): Number of original entities in the ground truth.
        found (int): Number of entities predicted by the model.
        right (int): Number of predicted entities that match the original
            annotation.

    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1
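
# Worked example with hypothetical counts: given 10 ground-truth entities,
# 8 predicted entities and 6 correct matches,
#   recall = 6 / 10 = 0.6
#   precision = 6 / 8 = 0.75
#   f1 = 2 * 0.75 * 0.6 / (0.75 + 0.6) ~= 0.667
# so _compute_f1(10, 8, 6) returns (0.6, 0.75, 0.666...).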


def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and f1-score for all categories.

    Args:
        pred_entities (list[list]): The predicted entities from the model.
        gt_entities (list[list]): The entities of the ground truth file.

    Returns:
        class_info (dict): precision, recall and f1-score for each
            category and for all categories combined.
    """
    origins = []
    founds = []
    rights = []
    for i, _ in enumerate(pred_entities):
        origins.extend(gt_entities[i])
        founds.extend(pred_entities[i])
        rights.extend([
            pred_entity for pred_entity in pred_entities[i]
            if pred_entity in gt_entities[i]
        ])

    class_info = {}
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    for type_, count in origin_counter.items():
        origin = count
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }

    origin = len(origins)
    found = len(founds)
    right = len(rights)
    recall, precision, f1 = _compute_f1(origin, found, right)
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info
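
# Illustrative usage (hypothetical entities in [category, start, end] form,
# one inner list per text line):
#
#   gt = [[['LOC', 0, 6], ['PER', 10, 13]]]
#   pred = [[['LOC', 0, 6]]]
#   compute_f1_all(pred, gt)['all']
#   # -> {'precision': 1.0, 'recall': 0.5, 'f1-score': 0.666...}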


def eval_ner_f1(results, gt_infos):
    """Evaluate for the NER task.

    Args:
        results (list): Predicted entity results.
        gt_infos (list[dict]): Ground-truth information which contains
            text and label.

    Returns:
        class_info (dict): precision, recall and f1-score of all
            entities and of each category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    pred_entities = []
    for i, _ in enumerate(gt_infos):
        line_entities = []
        for result in results[i]:
            line_entities.append(result)
        pred_entities.append(line_entities)
    assert len(pred_entities) == len(gt_entities)
    class_info = compute_f1_all(pred_entities, gt_entities)
    return class_info
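
# End-to-end sketch (hypothetical data; assumes each entry of `results` is a
# list of [category, start, end] entities, matching the ground-truth form
# produced by gt_label2entity):
#
#   gt_infos = [{'text': 'Beijing is big',
#                'label': {'LOC': {'Beijing': [[0, 6]]}}}]
#   results = [[['LOC', 0, 6]]]
#   eval_ner_f1(results, gt_infos)['all']['f1-score']  # -> 1.0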