tomofi's picture
Add application file
2366e36
raw
history blame
561 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn (int): Number of input features per timestep.
        nHidden (int): Hidden size of each LSTM direction. The forward and
            backward outputs are concatenated, giving ``2 * nHidden``
            features per timestep that feed the linear layer.
        nOut (int): Number of output features per timestep.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the bidirectional LSTM and project every timestep to ``nOut``.

        Args:
            input (Tensor): Sequence-first tensor of shape ``(T, b, nIn)``,
                because the LSTM uses the default ``batch_first=False``.
                An unbatched ``(T, nIn)`` tensor is also passed through,
                provided the installed PyTorch's ``nn.LSTM`` accepts
                unbatched input.

        Returns:
            Tensor: ``(T, b, nOut)`` for batched input, or ``(T, nOut)``
            for unbatched input.
        """
        # `input` shadows the builtin but is kept so keyword callers still work.
        recurrent, _ = self.rnn(input)  # [T, b, 2 * nHidden]
        # nn.Linear operates on the last dimension, so no flatten/unflatten
        # round trip is needed. This also avoids `.view`'s contiguity
        # requirement and the rank-3-only `T, b, h = size()` unpacking.
        return self.embedding(recurrent)  # [T, b, nOut]