# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none,
            the file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
        min_char_area (int): When decoding, a connected component is only
            emitted as a character if its majority class covers at least
            this many pixels. Defaults to 20.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 min_char_area=20,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)
        assert isinstance(min_char_area, int) and min_char_area >= 0

        self.with_unknown = with_unknown
        self.lower = lower
        self.min_char_area = min_char_area
        self.update_dict()

    def update_dict(self):
        """Add the background and (optional) unknown tokens to the dict.

        The background token is inserted at index 0. When ``with_unknown``
        is True an unknown token is appended as the last class and its index
        is stored in ``self.unknown_idx``; otherwise ``self.unknown_idx`` is
        None. ``self.char2idx`` is then rebuilt from ``self.idx2char``.
        """
        # Distinct, non-empty tokens: registering both as '' would make
        # char2idx[''] resolve to the unknown index instead of background.
        self.idx2char.insert(0, '<BG>')

        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        self.char2idx = {char: idx for idx, char in enumerate(self.idx2char)}

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.

        For every sample the per-pixel class map (argmax over channels,
        cropped to the valid width) is split into connected foreground
        components (class != 0). Each component votes for its majority
        class; components are read left to right by centroid x.

        Args:
            output (tensor): Model outputs with size: N * C * H * W
            img_metas (list[dict]): Each dict contains one image info and
                must provide a ``valid_ratio`` entry.

        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores, i.e. for each
                emitted char the fraction of its component's pixels that
                belong to the majority class.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)

        texts, scores = [], []
        for b in range(output.size(0)):
            seg_pred = output[b].detach()
            # Only decode the valid (un-padded) part of the width; slicing
            # clamps if this rounds past the real width.
            valid_width = int(
                output.size(-1) * img_metas[b]['valid_ratio'] + 1)
            seg_res = torch.argmax(
                seg_pred[:, :, :valid_width],
                dim=0).cpu().numpy().astype(np.int32)
            text, char_scores = self._decode_seg_map(seg_res)
            texts.append(text)
            scores.append(char_scores)

        return texts, scores

    def _decode_seg_map(self, seg_res):
        """Decode one 2-D class-index map into a string and char scores."""
        num_classes = len(self.idx2char)
        seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
        num_labels, labels, _, centroids = cv2.connectedComponentsWithStats(
            seg_thr)

        counts = self._component_class_counts(labels, seg_res, num_labels,
                                              num_classes)
        # np.argmax keeps the first maximum, so ties go to the lowest class
        # index. An all-zero row (e.g. an empty background label) maps to
        # class 0 and is skipped below.
        max_cls = counts.argmax(axis=1)
        max_num = counts[np.arange(num_labels), max_cls]
        total_num = counts.sum(axis=1)

        keep = [
            i for i in range(num_labels)
            if max_cls[i] != 0 and max_num[i] >= self.min_char_area
        ]
        # Stable sort: components with equal centroid x keep label order.
        keep.sort(key=lambda i: centroids[i][0])

        chars, char_scores = [], []
        for i in keep:
            cls = int(max_cls[i])
            # cls is in [1, num_classes - 1]; the unknown class decodes to ''.
            chars.append('' if cls == self.unknown_idx else self.idx2char[cls])
            char_scores.append(float(max_num[i]) / float(total_num[i]))
        return ''.join(chars), char_scores

    @staticmethod
    def _component_class_counts(labels, seg_res, num_labels, num_classes):
        """Return a (num_labels, num_classes) pixel histogram per component.

        Pixels whose class index lies outside ``[0, num_classes)`` are
        ignored, so they never take part in the vote.
        """
        flat_labels = labels.ravel().astype(np.int64)
        flat_cls = seg_res.ravel().astype(np.int64)
        in_dict = (flat_cls >= 0) & (flat_cls < num_classes)
        joint = flat_labels[in_dict] * num_classes + flat_cls[in_dict]
        return np.bincount(
            joint, minlength=num_labels * num_classes).reshape(
                num_labels, num_classes)