MMOCR / mmocr /models /textdet /postprocess /textsnake_postprocessor.py
tomofi's picture
Add application file
2366e36
raw
history blame
4.81 kB
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
import torch
from skimage.morphology import skeletonize
from mmocr.models.builder import POSTPROCESSOR
from .base_postprocessor import BasePostprocessor
from .utils import centralize, fill_hole, merge_disks
@POSTPROCESSOR.register_module()
class TextSnakePostprocessor(BasePostprocessor):
"""Decoding predictions of TextSnake to instances. This was partially
adapted from https://github.com/princewang1994/TextSnake.pytorch.
Args:
text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
min_text_region_confidence (float): The confidence threshold of text
region in TextSnake.
min_center_region_confidence (float): The confidence threshold of text
center region in TextSnake.
min_center_area (int): The minimal text center region area.
disk_overlap_thr (float): The radius overlap threshold for merging
disks.
radius_shrink_ratio (float): The shrink ratio of ordered disks radii.
"""
def __init__(self,
text_repr_type='poly',
min_text_region_confidence=0.6,
min_center_region_confidence=0.2,
min_center_area=30,
disk_overlap_thr=0.03,
radius_shrink_ratio=1.03,
**kwargs):
super().__init__(text_repr_type)
assert text_repr_type == 'poly'
self.min_text_region_confidence = min_text_region_confidence
self.min_center_region_confidence = min_center_region_confidence
self.min_center_area = min_center_area
self.disk_overlap_thr = disk_overlap_thr
self.radius_shrink_ratio = radius_shrink_ratio
def __call__(self, preds):
"""
Args:
preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
Returns:
list[list[float]]: The instance boundary and its confidence.
"""
assert preds.dim() == 3
preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
preds = preds.detach().cpu().numpy()
pred_text_score = preds[0]
pred_text_mask = pred_text_score > self.min_text_region_confidence
pred_center_score = preds[1] * pred_text_score
pred_center_mask = \
pred_center_score > self.min_center_region_confidence
pred_sin = preds[2]
pred_cos = preds[3]
pred_radius = preds[4]
mask_sz = pred_text_mask.shape
scale = np.sqrt(1.0 / (pred_sin**2 + pred_cos**2 + 1e-8))
pred_sin = pred_sin * scale
pred_cos = pred_cos * scale
pred_center_mask = fill_hole(pred_center_mask).astype(np.uint8)
center_contours, _ = cv2.findContours(pred_center_mask, cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
boundaries = []
for contour in center_contours:
if cv2.contourArea(contour) < self.min_center_area:
continue
instance_center_mask = np.zeros(mask_sz, dtype=np.uint8)
cv2.drawContours(instance_center_mask, [contour], -1, 1, -1)
skeleton = skeletonize(instance_center_mask)
skeleton_yx = np.argwhere(skeleton > 0)
y, x = skeleton_yx[:, 0], skeleton_yx[:, 1]
cos = pred_cos[y, x].reshape((-1, 1))
sin = pred_sin[y, x].reshape((-1, 1))
radius = pred_radius[y, x].reshape((-1, 1))
center_line_yx = centralize(skeleton_yx, cos, -sin, radius,
instance_center_mask)
y, x = center_line_yx[:, 0], center_line_yx[:, 1]
radius = (pred_radius[y, x] * self.radius_shrink_ratio).reshape(
(-1, 1))
score = pred_center_score[y, x].reshape((-1, 1))
instance_disks = np.hstack(
[np.fliplr(center_line_yx), radius, score])
instance_disks = merge_disks(instance_disks, self.disk_overlap_thr)
instance_mask = np.zeros(mask_sz, dtype=np.uint8)
for x, y, radius, score in instance_disks:
if radius > 1:
cv2.circle(instance_mask, (int(x), int(y)), int(radius), 1,
-1)
contours, _ = cv2.findContours(instance_mask, cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)
score = np.sum(instance_mask * pred_text_score) / (
np.sum(instance_mask) + 1e-8)
if (len(contours) > 0 and cv2.contourArea(contours[0]) > 0
and contours[0].size > 8):
boundary = contours[0].flatten().tolist()
boundaries.append(boundary + [score])
return boundaries