Crystalcareai committed on
Commit
90a26fc
1 Parent(s): 331e42c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -4
modeling_quiet.py CHANGED
@@ -688,15 +688,13 @@ class QuietSdpaAttention(QuietAttention):
688
  value_states = value_states.contiguous()
689
 
690
  attn_output = torch.nn.functional.scaled_dot_product_attention(
691
- query_states,
692
- key_states,
693
- value_states,
694
  attn_mask=attention_mask,
695
  dropout_p=self.attention_dropout if self.training else 0.0,
696
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
697
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
698
  )
699
 
 
700
  attn_output = attn_output.transpose(1, 2).contiguous()
701
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
702
 
 
688
  value_states = value_states.contiguous()
689
 
690
  attn_output = torch.nn.functional.scaled_dot_product_attention(
691
+ query_states, key_states, value_states,
 
 
692
  attn_mask=attention_mask,
693
  dropout_p=self.attention_dropout if self.training else 0.0,
 
694
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
695
  )
696
 
697
+
698
  attn_output = attn_output.transpose(1, 2).contiguous()
699
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
700