Spaces:
Runtime error
Runtime error
File size: 5,269 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.core import multi_apply
from mmocr.models.builder import HEADS
from ..postprocess.utils import poly_nms
from .head_mixin import HeadMixin
@HEADS.register_module()
class FCEHead(HeadMixin, BaseModule):
    """The class for implementing FCENet head.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        in_channels (int): The number of input channels.
        scales (list[int]) : The scale of each layer.
        fourier_degree (int) : The maximum Fourier transform degree k.
        nms_thr (float) : The threshold of nms.
        loss (dict): Config of loss for FCENet. The dict is copied before
            ``fourier_degree`` (and a deprecated ``num_sample``) is written
            into it.
        postprocessor (dict): Config of postprocessor for FCENet. The dict
            is copied before ``fourier_degree``, ``nms_thr`` and any
            deprecated keyword overrides are written into it.
        train_cfg: Stored unchanged as ``self.train_cfg``.
        test_cfg: Stored unchanged as ``self.test_cfg``.
        init_cfg (dict): Forwarded to ``BaseModule.__init__``.
        **kwargs: Only the deprecated keys ``text_repr_type``,
            ``decoding_type``, ``num_reconstr_points``, ``alpha``, ``beta``,
            ``score_thr`` and ``num_sample`` are read; each one that is
            given (not ``None``) is moved into the matching config dict
            and a ``UserWarning`` is emitted.
    """

    def __init__(self,
                 in_channels,
                 scales,
                 fourier_degree=5,
                 nms_thr=0.1,
                 loss=dict(type='FCELoss', num_sample=50),
                 postprocessor=dict(
                     type='FCEPostprocessor',
                     text_repr_type='poly',
                     num_reconstr_points=50,
                     alpha=1.0,
                     beta=2.0,
                     score_thr=0.3),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     mean=0,
                     std=0.01,
                     override=[
                         dict(name='out_conv_cls'),
                         dict(name='out_conv_reg')
                     ]),
                 **kwargs):
        # Work on shallow copies: both dicts are written to below, and the
        # default values are single objects shared by every call. Mutating
        # them in place would leak one instance's settings into the next
        # and would also modify dicts owned by the caller.
        loss = dict(loss)
        postprocessor = dict(postprocessor)

        old_keys = [
            'text_repr_type', 'decoding_type', 'num_reconstr_points', 'alpha',
            'beta', 'score_thr'
        ]
        for key in old_keys:
            value = kwargs.get(key)
            # Compare against None rather than truthiness so that explicit
            # falsy settings such as ``score_thr=0`` are not dropped.
            if value is not None:
                postprocessor[key] = value
                warnings.warn(
                    f'{key} is deprecated, please specify '
                    'it in postprocessor config dict. See '
                    'https://github.com/open-mmlab/mmocr/pull/640'
                    ' for details.', UserWarning)

        num_sample = kwargs.get('num_sample')
        if num_sample is not None:
            loss['num_sample'] = num_sample
            warnings.warn(
                'num_sample is deprecated, please specify '
                'it in loss config dict. See '
                'https://github.com/open-mmlab/mmocr/pull/640'
                ' for details.', UserWarning)

        BaseModule.__init__(self, init_cfg=init_cfg)

        # The loss and the postprocessor both receive the Fourier degree
        # used by this head; the postprocessor also receives the NMS
        # threshold.
        loss['fourier_degree'] = fourier_degree
        postprocessor['fourier_degree'] = fourier_degree
        postprocessor['nms_thr'] = nms_thr
        HeadMixin.__init__(self, loss, postprocessor)

        assert isinstance(in_channels, int)

        self.downsample_ratio = 1.0
        self.in_channels = in_channels
        self.scales = scales
        self.fourier_degree = fourier_degree
        self.nms_thr = nms_thr
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.out_channels_cls = 4
        # (2k + 1) * 2 == 4k + 2 channels; ``_get_boundary_single`` asserts
        # the same channel count on the regression map.
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, feats):
        """
        Args:
            feats (list[Tensor]): Each tensor has the shape of :math:`(N, C_i,
                H_i, W_i)`.

        Returns:
            list[[Tensor, Tensor]]: Each pair of tensors corresponds to the
            classification result and regression result computed from the input
            tensor with the same index. They have the shapes of :math:`(N,
            C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i, W_i)`.
        """
        cls_res, reg_res = multi_apply(self.forward_single, feats)
        level_num = len(cls_res)
        preds = [[cls_res[i], reg_res[i]] for i in range(level_num)]
        return preds

    def forward_single(self, x):
        """Apply both output convolutions to one feature map.

        Args:
            x (Tensor): Input accepted by ``out_conv_cls`` and
                ``out_conv_reg``.

        Returns:
            tuple: ``(cls_predict, reg_predict)``, the outputs of
            ``out_conv_cls`` and ``out_conv_reg`` respectively.
        """
        cls_predict = self.out_conv_cls(x)
        reg_predict = self.out_conv_reg(x)
        return cls_predict, reg_predict

    def get_boundary(self, score_maps, img_metas, rescale):
        """Collect boundaries from every level, apply NMS, optionally rescale.

        Args:
            score_maps (list): One entry per element of ``self.scales``;
                each entry is passed to ``_get_boundary_single`` together
                with the scale at the same index.
            img_metas (list[dict]): Only ``img_metas[0]['scale_factor']`` is
                read, and only when ``rescale`` is truthy.
            rescale (bool): If truthy, the boundaries are passed to
                ``self.resize_boundary`` with ``1.0 / scale_factor``.

        Returns:
            dict: ``dict(boundary_result=boundaries)``.
        """
        assert len(score_maps) == len(self.scales)

        boundaries = []
        for score_map, scale in zip(score_maps, self.scales):
            boundaries.extend(self._get_boundary_single(score_map, scale))

        # nms
        boundaries = poly_nms(boundaries, self.nms_thr)

        if rescale:
            boundaries = self.resize_boundary(
                boundaries, 1.0 / img_metas[0]['scale_factor'])

        results = dict(boundary_result=boundaries)
        return results

    def _get_boundary_single(self, score_map, scale):
        """Run the postprocessor on the prediction pair of a single level.

        Args:
            score_map (sequence): Must have exactly two elements; the second
                must have ``4 * fourier_degree + 2`` entries along axis 1.
            scale: The entry of ``self.scales`` for this level.

        Returns:
            The result of ``self.postprocessor(score_map, scale)``.
        """
        assert len(score_map) == 2
        assert score_map[1].shape[1] == 4 * self.fourier_degree + 2

        return self.postprocessor(score_map, scale)
|