# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter


def gt_label2entity(gt_infos):
    """Get all entities from ground truth infos.

    Args:
        gt_infos (list[dict]): Ground-truth information containing text and
            label.

    Returns:
        gt_entities (list[list]): Original labeled entities in ground truth.
            [[category, start_position, end_position]]
    """
    gt_entities = []
    for gt_info in gt_infos:
        line_entities = []
        label = gt_info['label']
        # label maps category -> entity text -> list of [start, end] spans.
        for key, value in label.items():
            for _, places in value.items():
                for place in places:
                    line_entities.append([key, place[0], place[1]])
        gt_entities.append(line_entities)
    return gt_entities


def _compute_f1(origin, found, right):
    """Calculate recall, precision and f1-score.

    Args:
        origin (int): Number of original entities in ground truth.
        found (int): Number of predicted entities from the model.
        right (int): Number of predicted entities that match the original
            annotation.

    Returns:
        recall (float): Metric of recall.
        precision (float): Metric of precision.
        f1 (float): Metric of f1-score.
    """
    recall = 0 if origin == 0 else (right / origin)
    precision = 0 if found == 0 else (right / found)
    f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (
        precision + recall)
    return recall, precision, f1


def compute_f1_all(pred_entities, gt_entities):
    """Calculate precision, recall and F1-score for all categories.

    Args:
        pred_entities (list[list]): The predicted entities from the model.
        gt_entities (list[list]): The entities of the ground truth file.

    Returns:
        class_info (dict): Precision, recall and f1-score, overall and per
            category.
    """
    origins = []
    founds = []
    rights = []
    for i, _ in enumerate(pred_entities):
        origins.extend(gt_entities[i])
        founds.extend(pred_entities[i])
        # A prediction counts as right only if it matches a ground-truth
        # entity in category, start position and end position.
        rights.extend([
            pred_entity for pred_entity in pred_entities[i]
            if pred_entity in gt_entities[i]
        ])

    class_info = {}
    # Count entities per category for per-category metrics.
    origin_counter = Counter([x[0] for x in origins])
    found_counter = Counter([x[0] for x in founds])
    right_counter = Counter([x[0] for x in rights])
    for type_, count in origin_counter.items():
        origin = count
        found = found_counter.get(type_, 0)
        right = right_counter.get(type_, 0)
        recall, precision, f1 = _compute_f1(origin, found, right)
        class_info[type_] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }
    # Micro-averaged metrics over all categories.
    origin = len(origins)
    found = len(founds)
    right = len(rights)
    recall, precision, f1 = _compute_f1(origin, found, right)
    class_info['all'] = {
        'precision': precision,
        'recall': recall,
        'f1-score': f1
    }
    return class_info


def eval_ner_f1(results, gt_infos):
    """Evaluate for NER task.

    Args:
        results (list): Predicted results of entities.
        gt_infos (list[dict]): Ground-truth information which contains text
            and label.

    Returns:
        class_info (dict): Precision, recall and f1-score, overall and per
            category.
    """
    assert len(results) == len(gt_infos)
    gt_entities = gt_label2entity(gt_infos)
    # Collect per-line predicted entities into the same nested-list layout
    # as gt_entities.
    pred_entities = []
    for i, _ in enumerate(gt_infos):
        line_entities = []
        for result in results[i]:
            line_entities.append(result)
        pred_entities.append(line_entities)
    assert len(pred_entities) == len(gt_entities)
    class_info = compute_f1_all(pred_entities, gt_entities)
    return class_info
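

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the CLUENER-style label format implied by ``gt_label2entity``:
# ``label`` maps category -> entity text -> list of [start, end] character
# spans, and each line of ``results`` is a list of predicted
# [category, start, end] entities. The sample sentence, categories and spans
# below are invented for demonstration only.
if __name__ == '__main__':
    sample_gt_infos = [{
        'text': 'John lives in Paris',
        'label': {
            'PER': {
                'John': [[0, 3]]
            },
            'LOC': {
                'Paris': [[14, 18]]
            }
        }
    }]
    # Predictions match the ground truth exactly, so every metric is 1.0.
    sample_results = [[['PER', 0, 3], ['LOC', 14, 18]]]
    print(eval_ner_f1(sample_results, sample_gt_infos))
    # -> {'PER': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0},
    #     'LOC': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0},
    #     'all': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0}}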