winfred2027 commited on
Commit
cffb9c2
1 Parent(s): 0c9a1d7

Upload mlp.py

Browse files
Files changed (1) hide show
  1. openshape/mlp.py +18 -0
openshape/mlp.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class MLP(nn.Module):
4
+ def __init__(self, in_out_features, hidden_features=512, drop=0.2):
5
+ super().__init__()
6
+ self.classifier = nn.Sequential(
7
+ nn.Linear(in_out_features, hidden_features),
8
+ nn.BatchNorm1d(hidden_features),
9
+ nn.GELU(),
10
+ nn.Dropout(drop),
11
+ nn.Linear(hidden_features, in_out_features),
12
+ nn.BatchNorm1d(in_out_features),
13
+ nn.GELU(),
14
+ nn.Dropout(drop),
15
+ )
16
+
17
+ def forward(self, x):
18
+ return self.classifier(x)