Crystalcareai committed
Commit 90a26fc · 1 Parent(s): 331e42c
Update modeling_quiet.py
modeling_quiet.py CHANGED (+2 -4)
@@ -688,15 +688,13 @@ class QuietSdpaAttention(QuietAttention):
         value_states = value_states.contiguous()
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
+            query_states, key_states, value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
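For reference, below is a minimal, self-contained sketch of how the scaled_dot_product_attention call reads after this commit. The dummy sizes (bsz, num_heads, q_len, head_dim) and the standalone variable names are illustrative assumptions only; in modeling_quiet.py these values come from the QuietSdpaAttention module's config and inputs.

# Minimal sketch, assuming dummy shapes; not the full QuietSdpaAttention module.
import torch

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16   # illustrative sizes only
hidden_size = num_heads * head_dim
is_causal = True
training = False
attention_dropout = 0.1
attention_mask = None  # with no explicit mask, the is_causal path below can apply

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim).contiguous()

# The call as it reads after this commit: q/k/v folded onto one line,
# the explanatory comment removed, the q_len > 1 guard kept.
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attention_mask,
    dropout_p=attention_dropout if training else 0.0,
    is_causal=is_causal and attention_mask is None and q_len > 1,
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, hidden_size)
print(attn_output.shape)  # torch.Size([2, 8, 64])

As the diff shows, the change is cosmetic for the q/k/v arguments (they are merged onto a single line) and drops the comment explaining the q_len > 1 check, while the check itself is unchanged.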