valhalla committed on
Commit
29db327
1 Parent(s): 95a8ed2

handle dtype for embeddings

Browse files
Files changed (1) hide show
  1. dalle_mini/modeling_bart_flax.py +4 -0
dalle_mini/modeling_bart_flax.py CHANGED
@@ -461,8 +461,10 @@ class FlaxBartEncoder(nn.Module):
461
  input_ids = input_ids.reshape(-1, input_shape[-1])
462
 
463
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
464
 
465
  embed_pos = self.embed_positions(position_ids + self.offset)
 
466
 
467
  hidden_states = inputs_embeds + embed_pos
468
  hidden_states = self.layernorm_embedding(hidden_states)
@@ -521,9 +523,11 @@ class FlaxBartDecoder(nn.Module):
521
  input_ids = input_ids.reshape(-1, input_shape[-1])
522
 
523
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
524
 
525
  # embed positions
526
  positions = self.embed_positions(position_ids + self.offset)
 
527
 
528
  hidden_states = inputs_embeds + positions
529
  hidden_states = self.layernorm_embedding(hidden_states)
 
461
  input_ids = input_ids.reshape(-1, input_shape[-1])
462
 
463
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
464
+ inputs_embeds = inputs_embeds.astype(self.dtype)
465
 
466
  embed_pos = self.embed_positions(position_ids + self.offset)
467
+ embed_pos = embed_pos.astype(self.dtype)
468
 
469
  hidden_states = inputs_embeds + embed_pos
470
  hidden_states = self.layernorm_embedding(hidden_states)
 
523
  input_ids = input_ids.reshape(-1, input_shape[-1])
524
 
525
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
526
+ inputs_embeds = inputs_embeds.astype(self.dtype)
527
 
528
  # embed positions
529
  positions = self.embed_positions(position_ids + self.offset)
530
+ positions = positions.astype(self.dtype)
531
 
532
  hidden_states = inputs_embeds + positions
533
  hidden_states = self.layernorm_embedding(hidden_states)