Spaces:
Running
Running
handle dtype for embeddings
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
@@ -461,8 +461,10 @@ class FlaxBartEncoder(nn.Module):
|
|
461 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
462 |
|
463 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
|
464 |
|
465 |
embed_pos = self.embed_positions(position_ids + self.offset)
|
|
|
466 |
|
467 |
hidden_states = inputs_embeds + embed_pos
|
468 |
hidden_states = self.layernorm_embedding(hidden_states)
|
@@ -521,9 +523,11 @@ class FlaxBartDecoder(nn.Module):
|
|
521 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
522 |
|
523 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
|
524 |
|
525 |
# embed positions
|
526 |
positions = self.embed_positions(position_ids + self.offset)
|
|
|
527 |
|
528 |
hidden_states = inputs_embeds + positions
|
529 |
hidden_states = self.layernorm_embedding(hidden_states)
|
|
|
461 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
462 |
|
463 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
464 |
+
inputs_embeds = inputs_embeds.astype(self.dtype)
|
465 |
|
466 |
embed_pos = self.embed_positions(position_ids + self.offset)
|
467 |
+
embed_pos = embed_pos.astype(self.dtype)
|
468 |
|
469 |
hidden_states = inputs_embeds + embed_pos
|
470 |
hidden_states = self.layernorm_embedding(hidden_states)
|
|
|
523 |
input_ids = input_ids.reshape(-1, input_shape[-1])
|
524 |
|
525 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
526 |
+
inputs_embeds = inputs_embeds.astype(self.dtype)
|
527 |
|
528 |
# embed positions
|
529 |
positions = self.embed_positions(position_ids + self.offset)
|
530 |
+
positions = positions.astype(self.dtype)
|
531 |
|
532 |
hidden_states = inputs_embeds + positions
|
533 |
hidden_states = self.layernorm_embedding(hidden_states)
|