Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
class PositionAwareLayer(nn.Module): | |
def __init__(self, dim_model, rnn_layers=2): | |
super().__init__() | |
self.dim_model = dim_model | |
self.rnn = nn.LSTM( | |
input_size=dim_model, | |
hidden_size=dim_model, | |
num_layers=rnn_layers, | |
batch_first=True) | |
self.mixer = nn.Sequential( | |
nn.Conv2d( | |
dim_model, dim_model, kernel_size=3, stride=1, padding=1), | |
nn.ReLU(True), | |
nn.Conv2d( | |
dim_model, dim_model, kernel_size=3, stride=1, padding=1)) | |
def forward(self, img_feature): | |
n, c, h, w = img_feature.size() | |
rnn_input = img_feature.permute(0, 2, 3, 1).contiguous() | |
rnn_input = rnn_input.view(n * h, w, c) | |
rnn_output, _ = self.rnn(rnn_input) | |
rnn_output = rnn_output.view(n, h, w, c) | |
rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous() | |
out = self.mixer(rnn_output) | |
return out | |