from typing import Callable, Union

from torch import nn


class MLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> dropout -> linear.

    The first linear layer maps ``input_dim`` features to ``hidden_dim``; the
    second maps back to ``input_dim``, so the output has the same feature
    size as the input.

    Args:
        input_dim: Number of input (and output) features.
        hidden_dim: Number of features in the hidden layer.
        dropout: Probability passed to ``nn.Dropout``, applied to the
            activated hidden features.
        activation: Either an ``nn.Module`` instance to use as-is, or a
            zero-argument callable (typically a module class such as
            ``nn.ReLU``) that returns one.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: Union[nn.Module, Callable[[], nn.Module]],
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        # Calling an already-built module with no arguments would invoke its
        # forward() without input and fail, so only call factories/classes.
        if isinstance(activation, nn.Module):
            self.activation = activation
        else:
            self.activation = activation()

    def forward(self, x):
        """Apply the block to ``x`` and return the projected result.

        ``x`` is fed straight into ``nn.Linear(input_dim, hidden_dim)``, so
        its last dimension must equal ``input_dim``.
        """
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))