Spaces:
Running
on
T4
Running
on
T4
File size: 546 Bytes
a277bb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from torch import nn
class MLP(nn.Module):
    """Two-layer feed-forward block whose output width equals its input width.

    Computes ``linear2(dropout(activation(linear1(x))))``. ``linear1`` projects
    ``input_dim -> hidden_dim`` and ``linear2`` projects
    ``hidden_dim -> input_dim``, so the last dimension of the output matches
    the last dimension of the input.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Width of the intermediate projection.
        dropout: Drop probability for the ``nn.Dropout`` applied after the
            activation.
        activation: Either an ``nn.Module`` instance (e.g. ``nn.ReLU()``),
            which is used as-is, or a zero-argument callable such as a module
            class (e.g. ``nn.ReLU``), which is called once to build the
            activation.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: "type[nn.Module] | nn.Module",
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        # Calling an already-built module would invoke its forward() with no
        # input, so instances are used directly. Anything else is treated as
        # a factory (typically the module class itself).
        if isinstance(activation, nn.Module):
            self.activation = activation
        else:
            self.activation = activation()

    def forward(self, x):
        """Apply the block to ``x``; its last dimension must be ``input_dim``."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))
|