fix: force correct mixed dtype after HF load
Browse files

modeling_hyena.py (+2 −1) — CHANGED
@@ -45,8 +45,9 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
             )
         self.vocab_size = vocab_size
         self.post_init()
+        self.force_dtype()
 
-    def  [removed line truncated in the original page rendering]
+    def force_dtype(self):
         self.backbone.to_bfloat16_except_poles_residues()
 
     def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):