Spaces:
Runtime error
Runtime error
File size: 4,970 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from rapidfuzz import string_metric
def cal_true_positive_char(pred, gt):
    """Calculate correct character number in prediction.

    The two strings are aligned with :class:`difflib.SequenceMatcher` and the
    lengths of all ``'equal'`` spans are added up.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        true_positive_char_num (int): The true positive number.
    """
    # ``autojunk`` has to be disabled: with the default heuristic, once ``gt``
    # reaches 200 characters every character that makes up more than ~1% of
    # it is dropped from the match index, so long lines would report far
    # fewer matching characters than they really share.
    matcher = SequenceMatcher(None, pred, gt, autojunk=False)
    return sum(gt_end - gt_start
               for tag, _, _, gt_start, gt_end in matcher.get_opcodes()
               if tag == 'equal')
def count_matches(pred_texts, gt_texts):
    """Count the various match number for metric calculation.

    Predictions and ground truths are compared pairwise (in ``zip`` order)
    at three strictness levels: exact, case-insensitive, and
    case-insensitive with symbols removed.

    Args:
        pred_texts (list[str]): Predicted text string.
        gt_texts (list[str]): Ground truth text string.

    Returns:
        match_res: (dict[str: int]): Match number used for
            metric calculation. It additionally holds ``'ned'`` (float),
            the summed normalized edit distance divided by
            ``max(1, len(gt_texts))``.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    # Everything that is not an ASCII letter, a digit or a character in
    # U+4E00..U+9FA5 is treated as a symbol and removed. ``^`` negates a
    # character class only as its first character, so it must not be
    # repeated inside the class: a repeated ``^`` is a literal caret and
    # would be kept as a valid character.
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res['match_word_num'] += 1

        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1

        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

        # normalized edit distance; the ``1`` in ``max`` avoids a division
        # by zero when both normalized strings are empty.
        # NOTE(review): ``string_metric`` comes from the file-level rapidfuzz
        # import; newer rapidfuzz releases reportedly replaced it with
        # ``rapidfuzz.distance.Levenshtein`` - confirm against the pinned
        # version.
        edit_dist = string_metric.levenshtein(pred_text_lower_ignore,
                                              gt_text_lower_ignore)
        norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
                                         len(pred_text_lower_ignore))
        norm_ed_sum += norm_ed

        # number to calculate char level recall & precision
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res['ned'] = normalized_edit_distance

    return match_res
def eval_ocr_metric(pred_texts, gt_texts):
    """Compute text recognition metrics: word accuracy, character-level
    recall/precision and 1-N.E.D.

    See https://rrc.cvc.uab.es/?ch=14&com=tasks for details on the metric.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth; must have the
            same length as ``pred_texts``.

    Returns:
        eval_res (dict[str: float]): Metric dict for text recognition, every
            value rounded to four decimals:

            - word_acc: Accuracy in word level.
            - word_acc_ignore_case: Accuracy in word level, ignore letter
              case.
            - word_acc_ignore_case_symbol: Accuracy in word level, ignore
              letter case and symbol. (default metric for academic
              evaluation)
            - char_recall: Recall in character level, ignore letter case
              and symbol.
            - char_precision: Precision in character level, ignore letter
              case and symbol.
            - 1-N.E.D: 1 - normalized_edit_distance.
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)

    counts = count_matches(pred_texts, gt_texts)
    # Small offset so that an empty evaluation set yields 0 instead of a
    # ZeroDivisionError.
    eps = 1e-8

    def _ratio(numerator_key, denominator_key):
        # Fraction of two counters taken from ``counts``.
        return 1.0 * counts[numerator_key] / (eps + counts[denominator_key])

    raw_scores = [
        ('word_acc', _ratio('match_word_num', 'gt_word_num')),
        ('word_acc_ignore_case',
         _ratio('match_word_ignore_case', 'gt_word_num')),
        ('word_acc_ignore_case_symbol',
         _ratio('match_word_ignore_case_symbol', 'gt_word_num')),
        ('char_recall', _ratio('true_positive_char_num', 'gt_char_num')),
        ('char_precision', _ratio('true_positive_char_num',
                                  'pred_char_num')),
        ('1-N.E.D', 1.0 - counts['ned']),
    ]

    # Round through a fixed-point string so every metric has 4 decimals.
    return {
        name: float('{:.4f}'.format(score))
        for name, score in raw_scores
    }
|