Dudu Moshe commited on
Commit
fc02e02
2 Parent(s): b30014f a3498bb

Merge pull request #7 from LightricksResearch/feature/fix-transformer-init-bug

Browse files
xora/models/transformers/transformer3d.py CHANGED
@@ -186,14 +186,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
186
 
187
  # Zero-out adaLN modulation layers in PixArt blocks:
188
  for block in self.transformer_blocks:
189
- if mode == "xora":
190
  nn.init.constant_(block.attn1.to_out[0].weight, 0)
191
  nn.init.constant_(block.attn1.to_out[0].bias, 0)
192
 
193
  nn.init.constant_(block.attn2.to_out[0].weight, 0)
194
  nn.init.constant_(block.attn2.to_out[0].bias, 0)
195
 
196
- if mode == "xora":
197
  nn.init.constant_(block.ff.net[2].weight, 0)
198
  nn.init.constant_(block.ff.net[2].bias, 0)
199
 
 
186
 
187
  # Zero-out adaLN modulation layers in PixArt blocks:
188
  for block in self.transformer_blocks:
189
+ if mode.lower() == "xora":
190
  nn.init.constant_(block.attn1.to_out[0].weight, 0)
191
  nn.init.constant_(block.attn1.to_out[0].bias, 0)
192
 
193
  nn.init.constant_(block.attn2.to_out[0].weight, 0)
194
  nn.init.constant_(block.attn2.to_out[0].bias, 0)
195
 
196
+ if mode.lower() == "xora":
197
  nn.init.constant_(block.ff.net[2].weight, 0)
198
  nn.init.constant_(block.ff.net[2].bias, 0)
199