# Spaces: Starting on T4
from typing import Callable, Union

import torch
from torch import nn
class MLP(nn.Module):
    """Two-layer feed-forward block that maps ``input_dim`` back to ``input_dim``.

    The input is projected up to ``hidden_dim`` by ``linear1``, passed through
    the activation and dropout, then projected back to ``input_dim`` by
    ``linear2``. The output therefore has the same last-dimension size as the
    input.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Size of the intermediate layer.
        dropout: Dropout probability, applied after the activation.
        activation: Either an ``nn.Module`` subclass or other zero-argument
            factory (e.g. ``nn.GELU``), which is instantiated here, or an
            already-constructed ``nn.Module`` instance, which is used as-is.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: Union[nn.Module, Callable[[], nn.Module]],
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        # Calling an nn.Module *instance* with no arguments would invoke its
        # forward() and fail, so instances are stored directly; classes and
        # factories are called once to build the activation module.
        if isinstance(activation, nn.Module):
            self.activation = activation
        else:
            self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear1 -> activation -> dropout -> linear2`` to ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``input_dim``.
        """
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.linear2(hidden)