MMOCR / mmocr /models /textrecog /layers /position_aware_layer.py
tomofi's picture
Add application file
2366e36
raw
history blame
1.05 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class PositionAwareLayer(nn.Module):
    """Row-wise LSTM followed by a small convolutional mixer.

    Every row of the input feature map (the ``W`` positions at one height
    index) is fed through a stacked LSTM as an independent sequence. The
    LSTM output is put back into ``(N, C, H, W)`` layout and passed through
    ``Conv2d -> ReLU -> Conv2d``. All stages keep ``dim_model`` channels and
    both convolutions use ``kernel_size=3, stride=1, padding=1``, so the
    output has the same shape as the input.

    Args:
        dim_model (int): Channel count of the input feature map. Also used
            as the LSTM input/hidden size and the conv in/out channels.
        rnn_layers (int): Number of stacked LSTM layers. Defaults to 2.
    """

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()

        self.dim_model = dim_model

        # batch_first=True: the LSTM consumes (batch, seq_len, feature).
        self.rnn = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            batch_first=True)

        # 3x3 convs with padding=1 and stride=1 keep H and W unchanged.
        self.mixer = nn.Sequential(
            nn.Conv2d(
                dim_model, dim_model, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(
                dim_model, dim_model, kernel_size=3, stride=1, padding=1))

    def forward(self, img_feature):
        """Apply the row-wise LSTM and the convolutional mixer.

        Args:
            img_feature (Tensor): Feature map of shape ``(N, C, H, W)``
                where ``C`` must equal ``dim_model``.

        Returns:
            Tensor: Feature map of shape ``(N, C, H, W)``.

        Raises:
            ValueError: If ``img_feature`` is not a 4-D tensor.
        """
        # Fail with a readable message instead of an opaque unpacking error.
        if img_feature.dim() != 4:
            raise ValueError(
                'img_feature must be a 4-D tensor of shape (N, C, H, W)')
        n, c, h, w = img_feature.size()

        # (N, C, H, W) -> (N, H, W, C) -> (N * H, W, C): each image row
        # becomes one sequence of length W with C features per step.
        rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
        rnn_input = rnn_input.view(n * h, w, c)
        # Final hidden/cell states are not needed, only per-step outputs.
        rnn_output, _ = self.rnn(rnn_input)

        # (N * H, W, C) -> (N, H, W, C) -> (N, C, H, W) for the convs.
        rnn_output = rnn_output.view(n, h, w, c)
        rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()

        out = self.mixer(rnn_output)

        return out