# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch import nn

from mmocr.models.builder import HEADS


@HEADS.register_module()
class SegHead(BaseModule):
    """Head for segmentation based text recognition.

    Args:
        in_channels (int): Number of input channels :math:`C`.
        num_classes (int): Number of output classes :math:`C_{out}`.
        upsample_param (dict | None): Keyword arguments passed to
            ``F.interpolate``. If None (the default), the segmentation map is
            not resized, which is equivalent to
            ``dict(scale_factor=1.0, mode='nearest')``.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels=128,
                 num_classes=37,
                 upsample_param=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_classes, int)
        assert num_classes > 0
        assert upsample_param is None or isinstance(upsample_param, dict)

        self.upsample_param = upsample_param

        self.seg_conv = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN'))

        # prediction: 1x1 convolution mapping features to ``num_classes``
        # output channels
        self.pred_conv = nn.Conv2d(
            in_channels, num_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, out_neck):
        """Compute the segmentation map from the last neck feature map.

        Args:
            out_neck (list[Tensor]): A list of tensors of shape
                :math:`(N, C_i, H_i, W_i)`. Only the last one
                (``out_neck[-1]``) is used.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, kH, kW)` where
            :math:`k` is determined by ``upsample_param``.
        """

        seg_map = self.seg_conv(out_neck[-1])
        seg_map = self.pred_conv(seg_map)

        if self.upsample_param is not None:
            seg_map = F.interpolate(seg_map, **self.upsample_param)

        return seg_map
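

# Illustrative sketch, not part of the original module: the helper below is a
# hypothetical addition that spells out how ``upsample_param`` determines the
# spatial size of the map returned by ``SegHead.forward``. It assumes that
# ``upsample_param`` carries either ``size`` or ``scale_factor`` (the two
# sizing arguments of ``F.interpolate``) and that, for ``scale_factor``, the
# output size is ``floor(input_size * scale_factor)``.
def _expected_seg_map_size(height, width, upsample_param=None):
    """Return the ``(H, W)`` expected from ``SegHead.forward`` (sketch).

    Args:
        height (int): Height of ``out_neck[-1]``.
        width (int): Width of ``out_neck[-1]``.
        upsample_param (dict | None): Same meaning as in ``SegHead``.

    Returns:
        tuple[int, int]: Expected output height and width.
    """
    # Both convolutions keep the spatial size (3x3 with padding 1, then 1x1),
    # so without interpolation the input size is returned unchanged.
    if upsample_param is None:
        return height, width

    # An explicit target size takes the place of any scaling.
    size = upsample_param.get('size')
    if size is not None:
        if isinstance(size, int):
            return size, size
        return tuple(size)

    scale = upsample_param.get('scale_factor')
    if scale is None:
        raise ValueError('upsample_param needs "size" or "scale_factor"')
    if isinstance(scale, (int, float)):
        scale = (scale, scale)
    # int() truncates, which equals floor for the positive sizes used here.
    return int(height * scale[0]), int(width * scale[1])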