Update triton_flash_atn.py
triton_flash_atn.py  +184 -493
@@ -25,27 +25,17 @@ import triton.language as tl

# TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')

# AMD E5M2B16
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m,
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
                    N_CTX,
                    pre_load_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M

@@ -93,119 +83,37 @@ def _attn_fwd_inner

# re-tuning.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16, 'waves_per_eu': 2,
                       'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 32, 'pre_load_v': False}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'waves_per_eu': 2,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
                       'slice_k_tile': 0, 'pre_load_v': True}, num_stages=1, num_warps=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 3,
                       'slice_k_tile': 0, 'pre_load_v': False}, num_stages=1, num_warps=1),
    ],
    key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,
              stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk,
              stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on,
              Z, H,
              N_CTX,
              BLOCK_DMODEL: tl.constexpr,
              STAGE: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              pre_load_v: tl.constexpr,
              ):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh

@@ -261,45 +169,23 @@ def _attn_fwd

    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                        start_m,
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                                        4 - STAGE, offs_m, offs_n, N_CTX,
                                        pre_load_v,
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compielr to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                        start_m,
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                                        2, offs_m, offs_n, N_CTX,
                                        pre_load_v,
                                        )
    # epilogue
    # write back m
    acc = acc / l_i[:, None]

@@ -309,46 +195,36 @@ def _attn_fwd


@triton.jit
def _attn_bwd_preprocess(O, DO,
                         Delta,
                         Z, H, N_CTX,
                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
                         ):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, D_HEAD)
    o = tl.load(O + off_hz * D_HEAD * N_CTX +
                off_m[:, None] * D_HEAD + off_n[None, :])
    do = tl.load(DO + off_hz * D_HEAD * N_CTX +
                 off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv,
                   Q, k, v, sm_scale,
                   DO,
                   M, D,
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,
                   H, N_CTX, BLOCK_M1: tl.constexpr,
                   BLOCK_N1: tl.constexpr,
                   BLOCK_DMODEL: tl.constexpr,
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,
                   MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, BLOCK_DMODEL)

@@ -358,7 +234,7 @@ def _attn_bwd_dkdv

        strides=(stride_d, stride_tok),
        offsets=(0, start_m),
        block_shape=(BLOCK_DMODEL, BLOCK_M1),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO,

@@ -366,7 +242,7 @@ def _attn_bwd_dkdv

        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M1, BLOCK_DMODEL),
        order=(1, 0)
    )
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)

@@ -381,7 +257,7 @@ def _attn_bwd_dkdv

        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(DO_block_ptr)
        # Compute dV.

@@ -404,28 +280,17 @@ def _attn_bwd_dkdv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,
                 H, N_CTX,
                 BLOCK_M2: tl.constexpr,
                 BLOCK_N2: tl.constexpr,
                 BLOCK_DMODEL: tl.constexpr,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,
                 MASK: tl.constexpr):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, BLOCK_DMODEL)

@@ -435,7 +300,7 @@ def _attn_bwd_dq

        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1)
    )
    VT_block_ptr = tl.make_block_ptr(
        base=V,

@@ -443,7 +308,7 @@ def _attn_bwd_dq

        strides=(stride_d, stride_tok),
        offsets=(0, start_n),
        block_shape=(BLOCK_DMODEL, BLOCK_N2),
        order=(0, 1)
    )
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)

@@ -458,7 +323,7 @@ def _attn_bwd_dq

        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = (offs_m[:, None] >= offs_n[None, :])
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        vT = tl.load(VT_block_ptr)

@@ -477,135 +342,42 @@ def _attn_bwd_dq


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=4),
        triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
                      num_stages=1, num_warps=8),
    ],
    key=['H', 'N_CTX', 'BLOCK_DMODEL'],
)
@triton.jit
def _attn_bwd(Q, K, V, sm_scale,
              DO,
              DQ, DK, DV,
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,
              # H = 16, N_CTX = 1024
              H, N_CTX,
              BLOCK_DMODEL: tl.constexpr,
              BLOCK_M1: tl.constexpr,
              BLOCK_N1: tl.constexpr,
              BLOCK_M2: tl.constexpr,
              BLOCK_N2: tl.constexpr,
              BLK_SLICE_FACTOR: tl.constexpr):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)

@@ -661,54 +433,31 @@ def _attn_bwd

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(dk, dv,
                            Q, k, v, sm_scale,
                            DO,
                            M, D,
                            stride_tok, stride_d,
                            H, N_CTX,
                            MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
                            start_n, start_m, num_steps,
                            MASK=True
                            )

    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(
        dk, dv,
        Q, k, v, sm_scale,
        DO,
        M, D,
        stride_tok, stride_d,
        H, N_CTX,
        BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
        start_n, start_m, num_steps,
        MASK=False
    )

    DV_block_ptrs = tl.make_block_ptr(

@@ -717,7 +466,7 @@ def _attn_bwd

        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DV_block_ptrs, dv.to(tl.float16))

@@ -729,7 +478,7 @@ def _attn_bwd

        strides=(stride_tok, stride_d),
        offsets=(start_n, 0),
        block_shape=(BLOCK_N1, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DK_block_ptrs, dk.to(tl.float16))

@@ -746,7 +495,7 @@ def _attn_bwd

        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )

    DO_block_ptr = tl.make_block_ptr(

@@ -755,7 +504,7 @@ def _attn_bwd

        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )
    q = tl.load(Q_block_ptr)
    do = tl.load(DO_block_ptr)

@@ -770,49 +519,25 @@ def _attn_bwd

    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,
                      do, m, D,
                      stride_tok, stride_d,
                      H, N_CTX,
                      BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
                      MASK=True
                      )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V,
                      do, m, D,
                      stride_tok, stride_d,
                      H, N_CTX,
                      BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
                      start_m, end_n - num_steps * BLOCK_N2, num_steps,
                      MASK=False
                      )
    # Write back dQ.
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ,

@@ -820,7 +545,7 @@ def _attn_bwd

        strides=(stride_tok, stride_d),
        offsets=(start_m, 0),
        block_shape=(BLOCK_M2, BLOCK_DMODEL),
        order=(1, 0)
    )
    dq *= LN2
    tl.store(DQ_block_ptr, dq.to(tl.float16))

@@ -849,41 +574,20 @@ class _attention(torch.autograd.Function):

        num_stages = 7 if Lk >= 64 else 3
        stage = 3 if causal else 1

        def grid(META): return (
            triton.cdiv(q.shape[2], META['BLOCK_M']),
            q.shape[0] * q.shape[1],
            1
        )
        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
                        device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,

@@ -925,39 +629,26 @@ class _attention(torch.autograd.Function):

        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o, do,
            delta,
            BATCH, N_HEAD, N_CTX,
            BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
        )

        def grid(META): return (
            triton.cdiv(N_CTX, META['BLOCK_N1']),
            1,
            BATCH * N_HEAD
        )
        _attn_bwd[grid](
            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
            M, delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            N_HEAD, N_CTX,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL
        )

        return dq, dk, dv, None, None


attention = _attention.apply
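
For reference, a minimal usage sketch of the exported `attention` callable. The argument order `(q, k, v, causal, sm_scale)` and the float16 `(Z, H, N_CTX, D_HEAD)` layout are inferred from the launch code above (the backward pass returns five gradients and the forward launch reads `q.shape[0]`, `q.shape[1]`, `q.shape[2]`), not stated explicitly in this diff; the import path and shapes below are hypothetical.

import torch
from triton_flash_atn import attention  # hypothetical import path for this file

# Hypothetical shapes: Z=2 batches, H=16 heads, N_CTX=1024 tokens, D_HEAD=64.
q = torch.randn(2, 16, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
sm_scale = q.shape[-1] ** -0.5

# causal=True selects STAGE=3 in the forward launch; sm_scale is stashed on ctx for backward.
out = attention(q, k, v, True, sm_scale)
out.sum().backward()  # exercises _attn_bwd_preprocess and _attn_bwd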