# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter


def gt_label2entity(gt_infos):
    """Get all entities from ground truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information which contains
            text and label.
    Returns:
        gt_entities (list[list]): Original labeled entities in ground truth,
            as [[category, start_position, end_position]].
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        label = gt_info['label']
        for key, value in label.items():
            for _, places in value.items():
                for place in places:
                    line_entities.append([key, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities
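

# Illustrative sketch of the label layout that gt_label2entity expects
# (the example data below is hypothetical, not taken from any dataset):
# the label dict maps each category to a dict of entity strings, each
# with a list of [start, end] character spans, e.g.
#
#     gt_infos = [{
#         'text': 'John lives in Paris',
#         'label': {
#             'PER': {'John': [[0, 3]]},
#             'LOC': {'Paris': [[14, 18]]},
#         },
#     }]
#
# which gt_label2entity would convert to
# [[['PER', 0, 3], ['LOC', 14, 18]]].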


def _compute_f1(origin, found, right):
    """Calculate recall, precision and f1-score.

    Args:
        origin (int): Number of original entities in the ground truth.
        found (int): Number of entities predicted by the model.
        right (int): Number of predicted entities that can be matched to
            the original annotation.
    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1
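

# Quick sanity check of the formulas above, using hypothetical counts:
# with origin=4, found=5 and right=3, recall = 3/4 = 0.75,
# precision = 3/5 = 0.6 and f1 = 2 * 0.6 * 0.75 / (0.6 + 0.75) ~ 0.667.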


def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and f1-score for all categories.

    Args:
        pred_entities: The predicted entities from the model.
        gt_entities: The entities of the ground truth file.
    Returns:
        class_info (dict): Precision, recall and f1-score in total and for
            each category.
    """
    origins = []
    founds = []
    rights = []
    for i, _ in enumerate(pred_entities):
        origins.extend(gt_entities[i])
        founds.extend(pred_entities[i])
        rights.extend([
            pre_entity for pre_entity in pred_entities[i]
            if pre_entity in gt_entities[i]
        ])
    class_info = {}
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    for type_, count in origin_counter.items():
        origin = count
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }
    origin = len(origins)
    found = len(founds)
    right = len(rights)
    recall, precision, f1 = _compute_f1(origin, found, right)
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info


def eval_ner_f1(results, gt_infos):
    """Evaluate for the NER task.

    Args:
        results (list): Predicted results of entities.
        gt_infos (list[dict]): Ground-truth information which contains
            text and label.
    Returns:
        class_info (dict): Precision, recall and f1-score in total and for
            each category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    pred_entities = []
    for i, gt_info in enumerate(gt_infos):
        line_entities = []
        for result in results[i]:
            line_entities.append(result)
        pred_entities.append(line_entities)
    assert len(pred_entities) == len(gt_entities)
    class_info = compute_f1_all(pred_entities, gt_entities)
    return class_info
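

if __name__ == '__main__':
    # Minimal usage sketch with hypothetical data (not part of the
    # original module). Predicted entities are [category, start, end]
    # triples and gt_infos follows the label layout illustrated after
    # gt_label2entity above.
    gt_infos = [{
        'text': 'John lives in Paris',
        'label': {
            'PER': {'John': [[0, 3]]},
            'LOC': {'Paris': [[14, 18]]},
        },
    }]
    # Predictions: the 'PER' span matches the annotation exactly, the
    # 'LOC' span is off by one character.
    results = [[['PER', 0, 3], ['LOC', 14, 17]]]
    class_info = eval_ner_f1(results, gt_infos)
    print(class_info)
    # Expected: precision/recall/f1 of 1.0 for 'PER', 0.0 for 'LOC'
    # and 0.5 for 'all'.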