# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmocr.models.builder import DECODERS
from mmocr.models.common.modules import PositionalEncoding
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class ABIVisionDecoder(BaseDecoder):
    """Converts visual features into text characters.

    Implementation of the decoder of the vision model in
        `ABINet <https://arxiv.org/abs/2103.06495>`_.

    Args:
        in_channels (int): Number of channels :math:`E` of input image
            features.
        num_channels (int): Number of channels of hidden vectors in mini
            U-Net.
        attn_height (int): Height :math:`H` of input image features.
        attn_width (int): Width :math:`W` of input image features.
        attn_mode (str): Upsampling mode for :obj:`torch.nn.Upsample` in mini
            U-Net.
        max_seq_len (int): Maximum text sequence length :math:`T`.
        num_chars (int): Number of text characters :math:`C`.
        init_cfg (dict): Specifies the initialization method for model layers.
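
    Example:
        A minimal usage sketch (illustrative only; the shapes assume the
        default arguments, and the random tensor stands in for real
        backbone features of shape (N, E, H, W)):

        >>> import torch
        >>> decoder = ABIVisionDecoder()
        >>> feat = torch.rand(1, 512, 8, 32)  # (N, E, H, W)
        >>> decoder.forward_train(feat)['logits'].shape
        torch.Size([1, 40, 90])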
    """

    def __init__(self,
                 in_channels=512,
                 num_channels=64,
                 attn_height=8,
                 attn_width=32,
                 attn_mode='nearest',
                 max_seq_len=40,
                 num_chars=90,
                 init_cfg=dict(type='Xavier', layer='Conv2d'),
                 **kwargs):
        super().__init__(init_cfg=init_cfg)

        self.max_seq_len = max_seq_len

        # For mini-Unet
        self.k_encoder = nn.Sequential(
            self._encoder_layer(in_channels, num_channels, stride=(1, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)),
            self._encoder_layer(num_channels, num_channels, stride=(2, 2)))

        self.k_decoder = nn.Sequential(
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels, num_channels, scale_factor=2, mode=attn_mode),
            self._decoder_layer(
                num_channels,
                in_channels,
                size=(attn_height, attn_width),
                mode=attn_mode))

        self.pos_encoder = PositionalEncoding(in_channels, max_seq_len)
        self.project = nn.Linear(in_channels, in_channels)
        self.cls = nn.Linear(in_channels, num_chars)

    def forward_train(self,
                      feat,
                      out_enc=None,
                      targets_dict=None,
                      img_metas=None):
        """
        Args:
            feat (Tensor): Image features of shape (N, E, H, W).
            out_enc (None): Unused.
            targets_dict (None): Unused.
            img_metas (None): Unused.

        Returns:
            dict: A dict with keys ``feature``, ``logits`` and ``attn_scores``.

            - | feature (Tensor): Shape (N, T, E). Raw visual features for
                language decoder.
            - | logits (Tensor): Shape (N, T, C). The raw logits for
                characters.
            - | attn_scores (Tensor): Shape (N, T, H, W). Intermediate result
                for vision-language aligner.
        """
        # Position Attention
        N, E, H, W = feat.size()
        k, v = feat, feat  # (N, E, H, W)

        # Apply mini U-Net on k: downsample through the encoder stages,
        # then upsample back with skip connections to the encoder feature
        # map of matching resolution
        features = []
        for i in range(len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            # Skip connection from the encoder stage of the same size
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # Queries are positional encodings of the T output steps, refined
        # by a linear projection
        zeros = feat.new_zeros((N, self.max_seq_len, E))  # (N, T, E)
        q = self.pos_encoder(zeros)  # (N, T, E)
        q = self.project(q)  # (N, T, E)

        # Scaled dot-product attention: each of the T queries attends over
        # all H*W key positions
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E**0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)
        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        logits = self.cls(attn_vecs)
        result = {
            'feature': attn_vecs,
            'logits': logits,
            'attn_scores': attn_scores.view(N, -1, H, W)
        }
        return result

    def forward_test(self, feat, out_enc=None, img_metas=None):
        return self.forward_train(feat, out_enc=out_enc, img_metas=img_metas)

    def _encoder_layer(self,
                       in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=2,
                       padding=1):
        return ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

    def _decoder_layer(self,
                       in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       mode='nearest',
                       scale_factor=None,
                       size=None):
        # align_corners is only meaningful for interpolating upsample modes
        align_corners = None if mode == 'nearest' else True
        return nn.Sequential(
            nn.Upsample(
                size=size,
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corners),
            ConvModule(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU')))
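

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the mmocr
    # API): build a decoder with default arguments and verify the output
    # shapes documented in forward_train. The random tensor is an assumed
    # stand-in for real backbone features of shape (N, E, H, W).
    decoder = ABIVisionDecoder()
    feat = torch.rand(2, 512, 8, 32)
    out = decoder.forward_train(feat)
    assert out['feature'].shape == (2, 40, 512)
    assert out['logits'].shape == (2, 40, 90)
    assert out['attn_scores'].shape == (2, 40, 8, 32)
    print('ABIVisionDecoder smoke test passed.')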