# tomofi's picture
# Add application file
# 2366e36
# raw
# history blame
# No virus
# 4.46 kB
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch
import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor
@CONVERTORS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    After construction, class index 0 is the ``'<BG>'`` (background) class
    and, when ``with_unknown`` is True, ``'<UKN>'`` is the last class.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the
            file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)
        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        """Add the background/unknown classes and rebuild ``char2idx``.

        Mutates ``self.idx2char`` in place: ``'<BG>'`` is inserted at index 0
        and, if ``self.with_unknown`` is True, ``'<UKN>'`` is appended and its
        index stored in ``self.unknown_idx`` (otherwise ``unknown_idx`` is
        None).
        """
        # background
        self.idx2char.insert(0, '<BG>')

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # Rebuild the reverse mapping: every original index shifted by one.
        self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.

        For each image, the per-pixel argmax class map (cropped to the valid
        width) is split into connected non-background components. Each
        component is labelled with its majority class. Components whose
        majority class is background, or whose majority class covers fewer
        than 20 pixels, are dropped. The rest are read left to right by
        centroid x-coordinate.

        Args:
            output (tensor): Model outputs with size: N * C * H * W
            img_metas (None|list[dict]): Each dict contains one image info.
                Its optional ``valid_ratio`` (default 1.0) scales how many
                columns of the output are decoded. If None, the full width
                of every image is decoded.

        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores. Each score is
                the fraction of a component's pixels voting for its majority
                class. Components decoded as ``'<UKN>'`` contribute a score
                but no character, so ``len(scores[i])`` may exceed
                ``len(texts[i])``.
        """
        batch_size = output.size(0)
        if img_metas is None:
            valid_ratios = [1.0] * batch_size
        else:
            assert utils.is_type_list(img_metas, dict)
            assert len(img_metas) == batch_size
            valid_ratios = [
                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
            ]

        texts, scores = [], []
        for b in range(batch_size):
            text, char_scores = self._decode_one(output[b].detach(),
                                                 valid_ratios[b])
            texts.append(text)
            scores.append(char_scores)
        return texts, scores

    def _decode_one(self, seg_pred, valid_ratio):
        """Decode one C * H * W score map into ``(text, char_scores)``."""
        # Components whose majority class has fewer pixels are treated as noise.
        min_char_pixels = 20
        num_classes = len(self.idx2char)

        valid_width = int(seg_pred.size(-1) * valid_ratio + 1)
        seg_res = torch.argmax(
            seg_pred[:, :, :valid_width],
            dim=0).cpu().numpy().astype(np.int32)
        # Foreground mask: every non-background pixel is foreground.
        seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
        _, labels, stats, centroids = cv2.connectedComponentsWithStats(
            seg_thr)

        candidates = []
        for comp_idx in range(stats.shape[0]):
            comp_classes = seg_res[labels == comp_idx]
            # Pixel count per dictionary class; predicted channels beyond the
            # dictionary size are ignored, as in the per-class loop before.
            counts = np.bincount(
                comp_classes, minlength=num_classes)[:num_classes]
            total = int(counts.sum())
            if total == 0:
                # No pixel maps to a known class (e.g. empty component).
                continue
            # np.argmax returns the first maximum, i.e. the lowest class
            # index wins ties.
            max_cls = int(np.argmax(counts))
            if max_cls == 0:
                # Majority is background (always true for label 0).
                continue
            max_num = int(counts[max_cls])
            if max_num < min_char_pixels:
                continue
            candidates.append(
                (max_cls, centroids[comp_idx][0], 1.0 * max_num / total))

        # Read characters left to right (stable sort on centroid x).
        candidates.sort(key=lambda cand: cand[1])

        chars, char_scores = [], []
        for cls_idx, _, score in candidates:
            # max_cls is always in [1, num_classes - 1] here; only the
            # unknown class needs special handling. Its score is still kept.
            if cls_idx != self.unknown_idx:
                chars.append(self.idx2char[cls_idx])
            char_scores.append(score)
        return ''.join(chars), char_scores