Spaces:
Runtime error
Runtime error
File size: 1,052 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class PositionAwareLayer(nn.Module):
    """Row-wise LSTM followed by a two-layer 3x3 convolutional mixer.

    Every row of the input feature map is treated as a sequence along the
    width axis and passed through a multi-layer LSTM whose hidden size
    equals its input size. The result is restored to the ``(N, C, H, W)``
    layout and refined by ``Conv2d -> ReLU -> Conv2d``.

    Args:
        dim_model: Channel count of the input feature map. It is used as
            the LSTM input size and hidden size, and as the in/out
            channel count of both convolutions.
        rnn_layers: Number of stacked LSTM layers.
    """

    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()

        self.dim_model = dim_model

        # Built before the mixer so parameter registration order (and thus
        # seeded initialisation and state_dict ordering) stays the same.
        self.rnn = nn.LSTM(dim_model, dim_model, rnn_layers, batch_first=True)

        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.mixer = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, **conv_cfg),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_model, dim_model, **conv_cfg),
        )

    def forward(self, img_feature):
        """Apply the row-wise LSTM and the convolutional mixer.

        Args:
            img_feature: 4-D tensor laid out as ``(N, C, H, W)``; ``C`` has
                to match ``dim_model`` for the LSTM and convolutions.

        Returns:
            Tensor with the same ``(N, C, H, W)`` shape as the input.
        """
        batch, channels, height, width = img_feature.shape

        # Move channels last, then fold height into the batch axis so that
        # each image row becomes one (width, channels) sequence.
        rows = img_feature.permute(0, 2, 3, 1).reshape(
            batch * height, width, channels)

        encoded, _ = self.rnn(rows)

        # Unfold the rows and return to the channels-first layout.
        encoded = encoded.reshape(batch, height, width, channels)
        encoded = encoded.permute(0, 3, 1, 2).contiguous()

        return self.mixer(encoded)
|