Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection of every time step.

    Args:
        nIn: Number of features in each input vector fed to the LSTM.
        nHidden: Hidden size of each LSTM direction.
        nOut: Number of features in each projected output vector.
    """

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # The LSTM concatenates its forward and backward hidden states, so
        # the projection consumes 2 * nHidden features per time step.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM and project each time step to ``nOut`` features.

        Args:
            input: Tensor laid out as ``[T, b, nIn]`` (the LSTM is built with
                the default ``batch_first=False``). An unbatched ``[T, nIn]``
                tensor is accepted as well.

        Returns:
            Tensor of shape ``[T, b, nOut]`` (``[T, nOut]`` for unbatched
            input).
        """
        recurrent, _ = self.rnn(input)  # [T, b, 2 * nHidden]
        # nn.Linear operates on the last dimension and keeps all leading
        # dimensions, so no flatten / un-flatten round trip through
        # ``view`` is needed (which would also require a contiguous tensor
        # and exactly three dimensions).
        return self.embedding(recurrent)