Aku Rouhe
commited on
Commit
•
0f77019
1
Parent(s):
8e970cf
small fix
Browse files
custom.py
CHANGED
@@ -6,7 +6,7 @@ class FeatureScaler(torch.nn.Module):
|
|
6 |
super().__init__()
|
7 |
self.scaler = torch.eye(num_in) * scale
|
8 |
|
9 |
-
def forward(x):
|
10 |
return x * self.scaler
|
11 |
|
12 |
class CustomInterface(sb.pretrained.interfaces.Pretrained):
|
|
|
6 |
super().__init__()
|
7 |
self.scaler = torch.eye(num_in) * scale
|
8 |
|
9 |
+
def forward(self, x):
|
10 |
return x * self.scaler
|
11 |
|
12 |
class CustomInterface(sb.pretrained.interfaces.Pretrained):
|