Update modelling_uniformer.py
Browse files- modelling_uniformer.py +2 -2
modelling_uniformer.py
CHANGED
@@ -99,8 +99,8 @@ class CBlock(nn.Module):
|
|
99 |
|
100 |
self.ls = layer_scale
|
101 |
if self.ls:
|
102 |
-
self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
|
103 |
-
self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)),requires_grad=True)
|
104 |
|
105 |
def forward(self, x):
|
106 |
x = x + self.pos_embed(x)
|
|
|
99 |
|
100 |
self.ls = layer_scale
|
101 |
if self.ls:
|
102 |
+
self.gamma_1 = nn.Parameter(init_value * torch.ones((1, dim, 1, 1)),requires_grad=True)
|
103 |
+
self.gamma_2 = nn.Parameter(init_value * torch.ones((1, dim, 1, 1)),requires_grad=True)
|
104 |
|
105 |
def forward(self, x):
|
106 |
x = x + self.pos_embed(x)
|