Crystalcareai commited on
Commit
7e59a14
1 Parent(s): cd6e834

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +3 -2
modeling_quiet.py CHANGED
@@ -942,9 +942,10 @@ class QuietModel(QuietPreTrainedModel):
942
  inputs_embeds=thought_embedding,
943
  attention_mask=None,
944
  use_cache=True,
 
945
  )
946
- logits = outputs.logits[:, -1, :]
947
- next_token_id = torch.argmax(logits, dim=-1)
948
 
949
  if next_token_id == self.config.end_token_id:
950
  break
 
942
  inputs_embeds=thought_embedding,
943
  attention_mask=None,
944
  use_cache=True,
945
+ return_dict=True, # Set return_dict=True
946
  )
947
+ logits = self.lm_head(outputs.last_hidden_state) # Use outputs.last_hidden_state instead of outputs.logits
948
+ next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
949
 
950
  if next_token_id == self.config.end_token_id:
951
  break