Spaces:
Runtime error
Runtime error
File size: 3,173 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList
from mmocr.models.builder import NECKS
@NECKS.register_module()
class FPNOCR(BaseModule):
    """FPN-like Network for segmentation based text recognition.

    Every input scale is first projected to ``out_channels`` channels by a
    1x1 lateral conv. Starting from the last lateral feature, the running
    feature map is upsampled by a factor of 2, concatenated along the channel
    dimension with the preceding lateral feature, and fused by a 1x1 conv
    followed by a 3x3 conv.

    Args:
        in_channels (list[int]): Number of input channels :math:`C_i` for each
            scale.
        out_channels (int): Number of output channels :math:`C_{out}` for each
            scale.
        last_stage_only (bool): If True, output last stage only.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 last_stage_only=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.last_stage_only = last_stage_only

        self.lateral_convs = ModuleList()
        self.smooth_convs_1x1 = ModuleList()
        self.smooth_convs_3x3 = ModuleList()

        # One 1x1 lateral conv per input scale: C_i -> out_channels.
        for i in range(self.num_ins):
            l_conv = ConvModule(
                in_channels[i], out_channels, 1, norm_cfg=dict(type='BN'))
            self.lateral_convs.append(l_conv)

        # One fusion stage per pair of adjacent scales (num_ins - 1 stages).
        # The 1x1 conv takes 2 * out_channels because its input is the
        # channel-wise concatenation of two out_channels-wide feature maps.
        for i in range(self.num_ins - 1):
            s_conv_1x1 = ConvModule(
                out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN'))
            s_conv_3x3 = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                norm_cfg=dict(type='BN'))
            self.smooth_convs_1x1.append(s_conv_1x1)
            self.smooth_convs_3x3.append(s_conv_3x3)

    def _upsample_x2(self, x):
        """Bilinearly upsample ``x`` by a factor of 2.

        Args:
            x (Tensor): Feature map to upsample.

        Returns:
            Tensor: ``x`` resized with ``scale_factor=2``.
        """
        # ``align_corners=False`` is the value PyTorch applies when the
        # argument is omitted, so the output is unchanged; passing it
        # explicitly avoids relying on an implicit default (some PyTorch
        # releases warn when it is left unspecified for bilinear mode).
        return F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, inputs):
        """
        Args:
            inputs (list[Tensor]): A list of n tensors. Each tensor has the
                shape of :math:`(N, C_i, H_i, W_i)`. It usually expects 4
                tensors (C2-C5 features) from ResNet.

        Returns:
            tuple(Tensor): A tuple of n-1 tensors. Each has the shape of
            :math:`(N, C_{out}, H_{n-2-i}, W_{n-2-i})`. If
            ``last_stage_only=True`` (default), the size of the
            tuple is 1 and only the last element will be returned.
            When the neck was built with a single input scale there is no
            fusion stage and the returned tuple is empty.
        """
        lateral_features = [
            l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs)
        ]

        outs = []
        # Walk the fusion stages from the last scale back to the first,
        # e.g. i = 3, 2, 1 for four input scales.
        for i in range(len(self.smooth_convs_3x3), 0, -1):
            # The first stage starts from the last lateral feature; every
            # later stage continues from the previous stage's output.
            last_out = outs[-1] if outs else lateral_features[-1]
            upsample = self._upsample_x2(last_out)
            upsample_cat = torch.cat((upsample, lateral_features[i - 1]),
                                     dim=1)
            smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat)
            smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1)
            outs.append(smooth_3x3)

        return tuple(outs[-1:]) if self.last_stage_only else tuple(outs)
|