File size: 3,786 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter


def gt_label2entity(gt_infos):
    """Collect the labeled entities of every ground-truth sample.

    Args:
        gt_infos (list[dict]): Ground-truth samples. The ``'label'`` entry
            of each sample maps a category to a dict whose values are
            lists of positions; the first two items of a position are
            used as start and end.

    Returns:
        list[list]: One list per sample, holding that sample's entities
            as ``[category, start_position, end_position]``.
    """
    all_entities = []
    for sample in gt_infos:
        # Flatten category -> mention -> positions into one list per sample.
        all_entities.append([
            [category, span[0], span[1]]
            for category, mentions in sample['label'].items()
            for spans in mentions.values()
            for span in spans
        ])
    return all_entities


def _compute_f1(origin, found, right):
    """Calculate recall, precision, f1-score.

    Args:
        origin (int): Original entities in groundtruth.
        found (int): Predicted entities from model.
        right (int): Predicted entities that
                        can match to the original annotation.
    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1


def compute_f1_all(pred_entities, gt_entities):
    """Compute precision, recall and f1-score per category and overall.

    Args:
        pred_entities: Predicted entities, one collection per sample. The
            first item of each entity is used as its category.
        gt_entities: Ground-truth entities, indexed in step with
            ``pred_entities``.

    Returns:
        dict: Maps every category present in the ground truth, plus the
            key ``'all'``, to a dict with ``'precision'``, ``'recall'``
            and ``'f1-score'``.
    """

    def _as_scores(n_gt, n_pred, n_hit):
        # Package the three metrics under the keys used in the report.
        recall, precision, f1 = _compute_f1(n_gt, n_pred, n_hit)
        return {'precision': precision, 'recall': recall, 'f1-score': f1}

    gt_pool, pred_pool, hit_pool = [], [], []
    for idx, line_preds in enumerate(pred_entities):
        line_gts = gt_entities[idx]
        gt_pool.extend(line_gts)
        pred_pool.extend(line_preds)
        # A prediction is a hit when the same sample's ground truth has an
        # equal entity.
        hit_pool.extend(
            entity for entity in line_preds if entity in line_gts)

    gt_per_type = Counter(entity[0] for entity in gt_pool)
    pred_per_type = Counter(entity[0] for entity in pred_pool)
    hit_per_type = Counter(entity[0] for entity in hit_pool)

    # Only categories that occur in the ground truth get their own entry.
    class_info = {
        category: _as_scores(n_gt, pred_per_type.get(category, 0),
                             hit_per_type.get(category, 0))
        for category, n_gt in gt_per_type.items()
    }
    class_info['all'] = _as_scores(
        len(gt_pool), len(pred_pool), len(hit_pool))
    return class_info


def eval_ner_f1(results, gt_infos):
    """Evaluate NER predictions against ground truth.

    Args:
        results (list): Predicted entities, one collection per sample.
        gt_infos (list[dict]): Ground-truth information which contains
            text and label; must have the same length as ``results``.

    Returns:
        dict: precision, recall and f1-score in total and for each
            category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    # Copy each sample's predictions into a plain list.
    pred_entities = [list(results[idx]) for idx in range(len(gt_infos))]
    assert len(pred_entities) == len(gt_entities)
    return compute_f1_all(pred_entities, gt_entities)