Upload triton_flash_blocksparse_attn.py
#21
by
barcelosallan
- opened
- triton_flash_blocksparse_attn.py +58 -56
triton_flash_blocksparse_attn.py
CHANGED
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
|
|
611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
|
|
638 |
if inference:
|
639 |
L, m = None, None
|
640 |
|
@@ -991,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
|
|
991 |
|
992 |
grid = (len(q_start_sids), n_heads)
|
993 |
|
994 |
-
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
-
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
|
|
1025 |
|
1026 |
return out
|
1027 |
|
@@ -1940,4 +1942,4 @@ if __name__ == '__main__':
|
|
1940 |
# 4 4096.0 3.401622 6.221376 1.636039
|
1941 |
# 5 8192.0 11.915136 23.483391 3.968725
|
1942 |
# 6 16384.0 44.660225 91.302910 10.857130
|
1943 |
-
# 7 32768.0 175.038467 359.048187 32.778240
|
|
|
611 |
# print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
|
612 |
# {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
|
613 |
|
614 |
+
with torch.cuda.device(q.device.index):
|
615 |
+
_fwd_kernel[grid](
|
616 |
+
q, k, v, sm_scale,
|
617 |
+
layout_crow_indices,
|
618 |
+
layout_col_indices,
|
619 |
+
layout_crow_indices.stride(0), layout_crow_indices.stride(1),
|
620 |
+
layout_col_indices.stride(0), layout_col_indices.stride(1),
|
621 |
+
tmp, L, m,
|
622 |
+
o,
|
623 |
+
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
624 |
+
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
625 |
+
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
626 |
+
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
627 |
+
q.shape[0], q.shape[1], k.shape[2],
|
628 |
+
k.shape[2] - q.shape[2],
|
629 |
+
q_rounded_len,
|
630 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
631 |
+
BLOCK_DMODEL=BLOCK_DMODEL,
|
632 |
+
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
|
633 |
+
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
|
634 |
+
INFERENCE=inference,
|
635 |
+
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
|
636 |
+
num_warps=num_warps,
|
637 |
+
num_stages=num_stages,
|
638 |
+
)
|
639 |
if inference:
|
640 |
L, m = None, None
|
641 |
|
|
|
992 |
|
993 |
grid = (len(q_start_sids), n_heads)
|
994 |
|
995 |
+
with torch.cuda.device(q.device.index):
|
996 |
+
_fwd_kernel_batch_inference[grid](
|
997 |
+
q, k, v, out,
|
998 |
+
sm_scale,
|
999 |
+
q_batch_starts,
|
1000 |
+
q_batch_ends,
|
1001 |
+
k_batch_starts,
|
1002 |
+
k_batch_ends,
|
1003 |
+
q_batch_ids,
|
1004 |
+
q_start_sids,
|
1005 |
+
|
1006 |
+
*q.stride(),
|
1007 |
+
*k.stride(),
|
1008 |
+
*v.stride(),
|
1009 |
+
*out.stride(),
|
1010 |
+
|
1011 |
+
layout_crow_indices,
|
1012 |
+
layout_col_indices,
|
1013 |
+
*layout_crow_indices.stride(),
|
1014 |
+
*layout_col_indices.stride(),
|
1015 |
+
|
1016 |
+
q_k_ratio,
|
1017 |
+
HAS_BATCH_DIM = True,
|
1018 |
+
D_HEAD = head_size,
|
1019 |
+
BLOCK_M = block_size,
|
1020 |
+
BLOCK_N = block_size,
|
1021 |
+
BLOCK_D = block_d,
|
1022 |
+
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
|
1023 |
+
EVEN_D = block_d == head_size,
|
1024 |
+
num_warps = 1 if q_len == 1 else 4,
|
1025 |
+
num_stages = 3
|
1026 |
+
)
|
1027 |
|
1028 |
return out
|
1029 |
|
|
|
1942 |
# 4 4096.0 3.401622 6.221376 1.636039
|
1943 |
# 5 8192.0 11.915136 23.483391 3.968725
|
1944 |
# 6 16384.0 44.660225 91.302910 10.857130
|
1945 |
+
# 7 32768.0 175.038467 359.048187 32.778240
|