File size: 582 Bytes
cffb9c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn

class MLP(nn.Module):
    """Two-stage feed-forward block whose output width equals its input width.

    Each stage is ``Linear -> BatchNorm1d -> GELU -> Dropout``. The first
    stage maps ``in_out_features`` to ``hidden_features``; the second maps
    back to ``in_out_features``. All eight layers live in one
    ``nn.Sequential`` stored as ``self.classifier``, so parameter names look
    like ``classifier.<idx>.weight``. The name is kept for checkpoint
    compatibility even though the module is not itself a classifier.

    Args:
        in_out_features: Feature width of both the input and the output.
        hidden_features: Width of the intermediate representation.
        drop: Dropout probability used after each GELU.

    Note:
        ``Linear`` acts on the last dim and ``BatchNorm1d`` on dim 1, so in
        practice the input is expected to be 2D ``(batch, in_out_features)``.
        In training mode ``BatchNorm1d`` needs a batch size greater than 1.
    """

    def __init__(self, in_out_features, hidden_features=512, drop=0.2):
        super().__init__()
        # (fan_in, fan_out) for the expanding stage, then the contracting one.
        stage_dims = (
            (in_out_features, hidden_features),
            (hidden_features, in_out_features),
        )
        layers = []
        for fan_in, fan_out in stage_dims:
            layers.extend(
                (
                    nn.Linear(fan_in, fan_out),
                    nn.BatchNorm1d(fan_out),
                    nn.GELU(),
                    nn.Dropout(drop),
                )
            )
        # Attribute name and layer order are preserved so saved state_dicts
        # (keys such as "classifier.0.weight") continue to load.
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both stages; the last dim keeps its width."""
        return self.classifier(x)