Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import cv2 | |
import numpy as np | |
import torch | |
from skimage.morphology import skeletonize | |
from mmocr.models.builder import POSTPROCESSOR | |
from .base_postprocessor import BasePostprocessor | |
from .utils import centralize, fill_hole, merge_disks | |
class TextSnakePostprocessor(BasePostprocessor): | |
"""Decoding predictions of TextSnake to instances. This was partially | |
adapted from https://github.com/princewang1994/TextSnake.pytorch. | |
Args: | |
text_repr_type (str): The boundary encoding type 'poly' or 'quad'. | |
min_text_region_confidence (float): The confidence threshold of text | |
region in TextSnake. | |
min_center_region_confidence (float): The confidence threshold of text | |
center region in TextSnake. | |
min_center_area (int): The minimal text center region area. | |
disk_overlap_thr (float): The radius overlap threshold for merging | |
disks. | |
radius_shrink_ratio (float): The shrink ratio of ordered disks radii. | |
""" | |
def __init__(self, | |
text_repr_type='poly', | |
min_text_region_confidence=0.6, | |
min_center_region_confidence=0.2, | |
min_center_area=30, | |
disk_overlap_thr=0.03, | |
radius_shrink_ratio=1.03, | |
**kwargs): | |
super().__init__(text_repr_type) | |
assert text_repr_type == 'poly' | |
self.min_text_region_confidence = min_text_region_confidence | |
self.min_center_region_confidence = min_center_region_confidence | |
self.min_center_area = min_center_area | |
self.disk_overlap_thr = disk_overlap_thr | |
self.radius_shrink_ratio = radius_shrink_ratio | |
def __call__(self, preds): | |
""" | |
Args: | |
preds (Tensor): Prediction map with shape :math:`(C, H, W)`. | |
Returns: | |
list[list[float]]: The instance boundary and its confidence. | |
""" | |
assert preds.dim() == 3 | |
preds[:2, :, :] = torch.sigmoid(preds[:2, :, :]) | |
preds = preds.detach().cpu().numpy() | |
pred_text_score = preds[0] | |
pred_text_mask = pred_text_score > self.min_text_region_confidence | |
pred_center_score = preds[1] * pred_text_score | |
pred_center_mask = \ | |
pred_center_score > self.min_center_region_confidence | |
pred_sin = preds[2] | |
pred_cos = preds[3] | |
pred_radius = preds[4] | |
mask_sz = pred_text_mask.shape | |
scale = np.sqrt(1.0 / (pred_sin**2 + pred_cos**2 + 1e-8)) | |
pred_sin = pred_sin * scale | |
pred_cos = pred_cos * scale | |
pred_center_mask = fill_hole(pred_center_mask).astype(np.uint8) | |
center_contours, _ = cv2.findContours(pred_center_mask, cv2.RETR_TREE, | |
cv2.CHAIN_APPROX_SIMPLE) | |
boundaries = [] | |
for contour in center_contours: | |
if cv2.contourArea(contour) < self.min_center_area: | |
continue | |
instance_center_mask = np.zeros(mask_sz, dtype=np.uint8) | |
cv2.drawContours(instance_center_mask, [contour], -1, 1, -1) | |
skeleton = skeletonize(instance_center_mask) | |
skeleton_yx = np.argwhere(skeleton > 0) | |
y, x = skeleton_yx[:, 0], skeleton_yx[:, 1] | |
cos = pred_cos[y, x].reshape((-1, 1)) | |
sin = pred_sin[y, x].reshape((-1, 1)) | |
radius = pred_radius[y, x].reshape((-1, 1)) | |
center_line_yx = centralize(skeleton_yx, cos, -sin, radius, | |
instance_center_mask) | |
y, x = center_line_yx[:, 0], center_line_yx[:, 1] | |
radius = (pred_radius[y, x] * self.radius_shrink_ratio).reshape( | |
(-1, 1)) | |
score = pred_center_score[y, x].reshape((-1, 1)) | |
instance_disks = np.hstack( | |
[np.fliplr(center_line_yx), radius, score]) | |
instance_disks = merge_disks(instance_disks, self.disk_overlap_thr) | |
instance_mask = np.zeros(mask_sz, dtype=np.uint8) | |
for x, y, radius, score in instance_disks: | |
if radius > 1: | |
cv2.circle(instance_mask, (int(x), int(y)), int(radius), 1, | |
-1) | |
contours, _ = cv2.findContours(instance_mask, cv2.RETR_TREE, | |
cv2.CHAIN_APPROX_SIMPLE) | |
score = np.sum(instance_mask * pred_text_score) / ( | |
np.sum(instance_mask) + 1e-8) | |
if (len(contours) > 0 and cv2.contourArea(contours[0]) > 0 | |
and contours[0].size > 8): | |
boundary = contours[0].flatten().tolist() | |
boundaries.append(boundary + [score]) | |
return boundaries | |