Spaces:
Runtime error
Runtime error
File size: 3,296 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np
from mmocr.core import points2boundary
from mmocr.models.builder import POSTPROCESSOR
from .base_postprocessor import BasePostprocessor
from .utils import box_score_fast, unclip
@POSTPROCESSOR.register_module()
class DBPostprocessor(BasePostprocessor):
    """Decoding predictions of DbNet to instances. This is partially adapted
    from https://github.com/MhLiao/DB.

    Args:
        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
        mask_thr (float): Threshold applied to the probability map to build
            the binary text mask from which contours are extracted.
        min_text_score (float): Minimum score (as computed by
            ``box_score_fast`` over the probability map inside a simplified
            contour) for that candidate to be kept.
        min_text_width (int): The minimum width of a predicted box. Only
            used when ``text_repr_type`` is 'quad' (forwarded to
            ``points2boundary``).
        unclip_ratio (float): The unclip ratio for text regions dilation.
        max_candidates (int): The maximum number of contours examined.
    """

    def __init__(self,
                 text_repr_type='poly',
                 mask_thr=0.3,
                 min_text_score=0.3,
                 min_text_width=5,
                 unclip_ratio=1.5,
                 max_candidates=3000,
                 **kwargs):
        # Extra keyword arguments are accepted (and ignored) so that configs
        # carrying additional keys can still build this postprocessor.
        super().__init__(text_repr_type)
        self.mask_thr = mask_thr
        self.min_text_score = min_text_score
        self.min_text_width = min_text_width
        self.unclip_ratio = unclip_ratio
        self.max_candidates = max_candidates

    def __call__(self, preds):
        """Decode a DBNet prediction map into text boundaries.

        Args:
            preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
                Only channel 0 is read, as the text probability map.

        Returns:
            list[list[float]]: The predicted text boundaries. In 'poly' mode
            each boundary is the flattened ``[x1, y1, x2, y2, ...]`` vertex
            list followed by its score; in 'quad' mode it is whatever
            ``points2boundary`` returns for the expanded polygon.
        """
        assert preds.dim() == 3

        prob_map = preds[0, :, :]
        text_mask = prob_map > self.mask_thr

        # ``detach()`` is the supported way to drop autograd history before
        # converting to numpy (``.data`` silently bypasses autograd checks).
        score_map = prob_map.detach().cpu().numpy().astype(np.float32)
        text_mask = text_mask.detach().cpu().numpy().astype(np.uint8)

        contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8),
                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

        boundaries = []
        for i, contour in enumerate(contours):
            # Examine at most ``max_candidates`` contours. The former check
            # (``i > max_candidates``) let one extra contour through.
            if i >= self.max_candidates:
                break
            boundary = self._contour_to_boundary(contour, score_map)
            if boundary is not None:
                boundaries.append(boundary)
        return boundaries

    def _contour_to_boundary(self, contour, score_map):
        """Convert one contour into a boundary, or return None to drop it.

        Args:
            contour (ndarray): A contour as returned by ``cv2.findContours``.
            score_map (ndarray): The float32 probability map.

        Returns:
            list[float] | None: The boundary in the configured
            representation, or None if the candidate is rejected.
        """
        # Simplify the contour with a tolerance of 1% of its perimeter.
        epsilon = 0.01 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        points = approx.reshape((-1, 2))
        # Fewer than 4 vertices is not treated as a usable text region.
        if points.shape[0] < 4:
            return None

        score = box_score_fast(score_map, points)
        if score < self.min_text_score:
            return None

        expanded = unclip(points, unclip_ratio=self.unclip_ratio)
        # Drop empty results, and results whose first element is a list
        # (presumably the offset produced several paths -- confirm against
        # ``unclip``).
        if len(expanded) == 0 or isinstance(expanded[0], list):
            return None
        expanded = expanded.reshape(-1, 2)

        boundary = expanded
        if self.text_repr_type == 'quad':
            boundary = points2boundary(expanded, self.text_repr_type, score,
                                       self.min_text_width)
        elif self.text_repr_type == 'poly':
            # ``score`` is never None here: the comparison above would
            # already have raised, so it is appended unconditionally.
            boundary = expanded.flatten().tolist() + [score]

        # Fewer than 8 values cannot hold four (x, y) vertices; the None
        # check guards against a helper returning None instead of [].
        if boundary is None or len(boundary) < 8:
            return None
        return boundary
|