Crystalcareai
committed on
Commit
•
44b539c
1
Parent(s):
7621f1c
Update modeling_quiet.py
Browse files
- modeling_quiet.py +5 -6
modeling_quiet.py
CHANGED
@@ -448,11 +448,10 @@ class QuietFlashAttention2(QuietAttention):
|
|
448 |
query_states = query_states.to(target_dtype)
|
449 |
key_states = key_states.to(target_dtype)
|
450 |
value_states = value_states.to(target_dtype)
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
value_states = value_states.transpose(1, 2)
|
456 |
|
457 |
attn_output = self._flash_attention_forward(
|
458 |
query_states,
|
@@ -462,7 +461,7 @@ class QuietFlashAttention2(QuietAttention):
|
|
462 |
q_len,
|
463 |
dropout=dropout_rate,
|
464 |
use_sliding_windows=use_sliding_windows,
|
465 |
-
|
466 |
|
467 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
468 |
attn_output = self.o_proj(attn_output)
|
|
|
448 |
query_states = query_states.to(target_dtype)
|
449 |
key_states = key_states.to(target_dtype)
|
450 |
value_states = value_states.to(target_dtype)
|
451 |
+
# Reshape to the expected shape for Flash Attention
|
452 |
+
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
453 |
+
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
454 |
+
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
|
|
|
455 |
|
456 |
attn_output = self._flash_attention_forward(
|
457 |
query_states,
|
|
|
461 |
q_len,
|
462 |
dropout=dropout_rate,
|
463 |
use_sliding_windows=use_sliding_windows,
|
464 |
+
)
|
465 |
|
466 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
467 |
attn_output = self.o_proj(attn_output)
|