# fa2_compilable.py
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Union
import torch
import torch.nn as nn
# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
# isort: on
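# The raw CUDA kernels are exposed to PyTorch as custom operators via torch.library
# (define + impl + impl_abstract) so that calls to them can be traced and compiled
# with torch.compile instead of graph-breaking on the extension module.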
torch.library.define("fa2::fwd", "(Tensor q, Tensor k, Tensor v, Tensor out, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, Tensor attn_bias, bool return_softmax, Tensor gen_) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)")
@torch.library.impl("fa2::fwd", "default")
def cuda_fa2_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
alibi_slopes: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
attn_bias: torch.Tensor,
return_softmax: bool,
gen_: torch.Tensor,
):
out, q, k, v, out_padded, attn_bias, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left, window_size_right, attn_bias, return_softmax, None)
return out, q, k, v, out_padded, attn_bias, softmax_lse, S_dmask, rng_state
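# Abstract ("meta") implementation of fa2::fwd: it only allocates empty tensors with the
# shapes, strides and dtypes the kernel would produce, so the op can be traced with fake
# tensors (e.g. under torch.compile) without touching the GPU.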
@torch.library.impl_abstract("fa2::fwd", cuda_fa2_fwd)
def meta_fa2_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
alibi_slopes: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
attn_bias: torch.Tensor,
return_softmax: bool,
gen_: torch.Tensor
):
round_multiple = lambda x, m: (x + m - 1) // m * m
batch_size = q.shape[0]
seqlen_q = q.shape[1]
seqlen_k = k.shape[1]
num_heads = q.shape[2]
head_dim_og = q.shape[3]
seqlen_q_rounded = round_multiple(seqlen_q, 128)
seqlen_k_rounded = round_multiple(seqlen_k, 128)
seqlen_q_rounded_8 = round_multiple(seqlen_q, 8)
seqlen_k_rounded_8 = round_multiple(seqlen_k, 8)
head_dim = round_multiple(head_dim_og, 8)
if attn_bias is not None:
batch_size_bias = attn_bias.shape[0]
num_heads_bias = attn_bias.shape[1]
return (torch.empty_strided((batch_size, seqlen_q, num_heads, head_dim_og),
(head_dim*num_heads*seqlen_q, head_dim*num_heads, head_dim, 1), device=q.device, dtype=q.dtype), # out
q.new_empty((batch_size, seqlen_q, num_heads, head_dim)), # q_padded
k.new_empty((batch_size, seqlen_k, num_heads, head_dim)), # k_padded
v.new_empty((batch_size, seqlen_k, num_heads, head_dim)), # v_padded
q.new_empty((batch_size, seqlen_q, num_heads, head_dim)), # out_padded
q.new_empty((batch_size_bias, num_heads_bias, seqlen_q_rounded_8, seqlen_k_rounded_8)) if attn_bias is not None else None, # attn_bias
q.new_empty((batch_size, num_heads, seqlen_q)), # softmax_lse
q.new_empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded)) if return_softmax and (dropout_p > 0) else None, # p
            torch.empty((2,), dtype=torch.int64, device=q.device)  # rng_state
)
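# Backward pass registered the same way: fa2::bwd wraps flash_attn_cuda.bwd and gets a
# matching meta implementation below for shape inference during tracing.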
torch.library.define("fa2::bwd", "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor dq, Tensor dk, Tensor dv, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, bool deterministic, Tensor attn_bias, bool attn_bias_require_grad, Tensor ds, int seqlen_k_orig, Tensor gen_, Tensor rng_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")
@torch.library.impl("fa2::bwd", "default")
def cuda_fa2_bwd(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: torch.Tensor,
dk: torch.Tensor,
dv: torch.Tensor,
alibi_slopes: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
attn_bias: torch.Tensor,
attn_bias_require_grad: bool,
ds: torch.Tensor,
seqlen_k_orig: int,
gen_: torch.Tensor,
    rng_state: torch.Tensor
):
    dq, dk, dv, ds, s = flash_attn_cuda.bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left, window_size_right, deterministic, attn_bias, attn_bias_require_grad, ds, None, rng_state)
return dq, dk, dv, ds, s
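# Abstract ("meta") implementation of fa2::bwd: returns empty tensors shaped like
# (dq, dk, dv, the optional bias gradient ds, and the auxiliary softmax_d tensor).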
@torch.library.impl_abstract("fa2::bwd", cuda_fa2_bwd)
def meta_fa2_bwd(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: torch.Tensor,
dk: torch.Tensor,
dv: torch.Tensor,
alibi_slopes: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
attn_bias: torch.Tensor,
attn_bias_require_grad: bool,
ds: torch.Tensor,
seqlen_k_orig: int,
gen_: torch.Tensor,
    rng_state: torch.Tensor
):
round_multiple = lambda x, m: (x + m - 1) // m * m
batch_size = dout.shape[0]
seqlen_q = dout.shape[1]
seqlen_k = k.shape[1]
seqlen_q_rounded = round_multiple(seqlen_q, 128)
num_heads = dout.shape[2]
head_dim_og = dout.shape[3]
head_dim = round_multiple(head_dim_og, 8)
seqlen_q_round8 = round_multiple(seqlen_q, 8)
seqlen_k_round8 = round_multiple(seqlen_k_orig, 8)
if attn_bias is not None:
batch_size_bias = attn_bias.shape[0]
num_heads_bias = attn_bias.shape[1]
return (torch.empty_strided((batch_size, seqlen_q, num_heads, head_dim_og),
(head_dim*num_heads*seqlen_q, head_dim*num_heads, head_dim, 1), device=q.device, dtype=q.dtype),
torch.empty_strided((batch_size, seqlen_k_orig, num_heads, head_dim_og),
(head_dim*num_heads*seqlen_k, head_dim*num_heads, head_dim, 1), device=k.device, dtype=k.dtype),
torch.empty_strided((batch_size, seqlen_k, num_heads, head_dim_og),
(head_dim*num_heads*seqlen_k, head_dim*num_heads, head_dim, 1), device=v.device, dtype=v.dtype),
torch.empty_strided((batch_size_bias, num_heads_bias, seqlen_q, seqlen_k_orig),
(num_heads_bias*seqlen_q_round8*seqlen_k_round8, seqlen_q_round8*seqlen_k_round8, seqlen_q_round8, 1), device=v.device, dtype=v.dtype)
if attn_bias_require_grad else None,
q.new_empty((batch_size, num_heads, seqlen_q_rounded))
)
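# torch.autograd.Function wrappers around the custom ops: they save what the backward
# kernel needs and route the optional attention-bias gradient (ds) back to the caller.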
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
qkv,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_softmax,
return_ds
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
None,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
attn_bias,
return_softmax and dropout_p > 0,
None
)
        ## WORKAROUND for a PyTorch bug: we should save the _padded versions of the tensors,
        ## but passing them directly to save_for_backward is buggy.
        ## For now this breaks the backward pass when headdim and/or seqlen_q / seqlen_k are not multiples of 8.
        ## TODO: do the padding here instead
ctx.save_for_backward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], out, softmax_lse, rng_state, attn_bias, alibi_slopes)
#ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.deterministic = deterministic
        ctx.bias_requires_grad = attn_bias is not None and return_ds
ctx.seqlen_k_orig = qkv.shape[1]
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
None,
None,
None,
alibi_slopes,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
attn_bias,
ctx.bias_requires_grad,
None,
ctx.seqlen_k_orig,
None,
rng_state
)
dqkv = torch.stack([dq, dk, dv], dim=2)
return dqkv, None, None, None, None, None, None, None, ds, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_softmax,
return_ds
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
q,
kv[:, :, 0],
kv[:, :, 1],
None,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
attn_bias,
return_softmax and dropout_p > 0,
None
)
        ## WORKAROUND for a PyTorch bug: we should save the _padded versions of the tensors,
        ## but passing them directly to save_for_backward is buggy.
        ## For now this breaks the backward pass when headdim and/or seqlen_q / seqlen_k are not multiples of 8.
        ## TODO: do the padding here instead
ctx.save_for_backward(q, kv[:, :, 0], kv[:, :, 1], out, softmax_lse, rng_state, attn_bias, alibi_slopes)
#ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.deterministic = deterministic
        ctx.bias_requires_grad = attn_bias is not None and return_ds
ctx.seqlen_k_orig = kv.shape[1]
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
None,
None,
None,
alibi_slopes,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
attn_bias,
ctx.bias_requires_grad,
None,
ctx.seqlen_k_orig,
None,
rng_state
)
dkv = torch.stack([dk, dv], dim=2)
return dq, dkv, None, None, None, None, None, None, None, ds, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_softmax,
return_ds
):
batch_size, seqlen_q = q.shape[:2]
seqlen_k = k.shape[1]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if attn_bias is not None:
attn_bias = attn_bias.to(q.dtype)
out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
q,
k,
v,
None,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
attn_bias,
return_softmax and dropout_p > 0,
None
)
        ## WORKAROUND for a PyTorch bug: we should save the _padded versions of the tensors,
        ## but passing them directly to save_for_backward is buggy.
        ## For now this breaks the backward pass when headdim and/or seqlen_q / seqlen_k are not multiples of 8.
        ## TODO: do the padding here instead
ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes)
#ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size_left = window_size_left
ctx.window_size_right = window_size_right
ctx.deterministic = deterministic
        ctx.bias_requires_grad = attn_bias is not None and return_ds
ctx.seqlen_k_orig = k.shape[1]
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
dout = dout.contiguous()
dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
dout,
q,
k,
v,
out,
softmax_lse,
None,
None,
None,
alibi_slopes,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size_left,
ctx.window_size_right,
ctx.deterministic,
attn_bias,
ctx.bias_requires_grad,
None,
ctx.seqlen_k_orig,
None,
rng_state
)
return dq, dk, dv, None, None, None, None, None, None, None, ds, None, None
def flash_attn_qkvpacked_func(
qkv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size_left=-1,
window_size_right=-1, # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
attn_bias=None,
return_attn_probs=False,
return_ds=False
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_kvpacked_func and flash_attn_func.
    If window_size_left or window_size_right is not -1, implements sliding window local attention.
    A query at position i will only attend to keys between [i - window_size_left, i + window_size_right] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left, window_size_right: int. If not -1, implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional additive bias of shape (batch_size, nheads, seqlen, seqlen), added to
            the attention scores before the softmax.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        return_ds: bool. If attn_bias is provided, whether to also compute its gradient (ds)
            in the backward pass.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(
qkv,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_attn_probs,
return_ds
)
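# Minimal usage sketch (illustrative, not part of the library API): the FA2 kernels expect
# fp16/bf16 tensors on a CUDA device, and the workaround noted above prefers head dims and
# sequence lengths that are multiples of 8.
#
#   qkv = torch.randn(2, 256, 3, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
#   out.sum().backward()   # the gradient comes back as a single stacked dqkv tensor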
def flash_attn_kvpacked_func(
q,
kv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size_left=-1,
window_size_right=-1, # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
attn_bias=None,
return_attn_probs=False,
return_ds=False
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
    If a row of the mask is all zero, the corresponding row of the output will be zero.
    If window_size_left or window_size_right is not -1, implements sliding window local attention.
    A query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left, window_size_right: int. If not -1, implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional additive bias of shape (batch_size, nheads, seqlen_q, seqlen_k), added to
            the attention scores before the softmax.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        return_ds: bool. If attn_bias is provided, whether to also compute its gradient (ds)
            in the backward pass.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_attn_probs,
return_ds
)
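# Minimal MQA/GQA sketch (illustrative only; assumes a CUDA device and fp16/bf16 inputs):
# q may have more heads than kv as long as nheads is divisible by nheads_k.
#
#   q  = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   kv = torch.randn(2, 256, 2, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   out = flash_attn_kvpacked_func(q, kv, causal=True)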
def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size_left=-1,
window_size_right=-1, # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
attn_bias=None,
return_attn_probs=False,
return_ds=False
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
    If a row of the mask is all zero, the corresponding row of the output will be zero.
    If window_size_left or window_size_right is not -1, implements sliding window local attention.
    A query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size_left, window_size_right: int. If not -1, implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
        attn_bias: optional additive bias of shape (batch_size, nheads, seqlen_q, seqlen_k), added to
            the attention scores before the softmax.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        return_ds: bool. If attn_bias is provided, whether to also compute its gradient (ds)
            in the backward pass.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
alibi_slopes,
deterministic,
attn_bias,
return_attn_probs,
return_ds,
)
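
if __name__ == "__main__":
    # Quick smoke test, included as a sketch rather than as part of the library API.
    # Assumptions: a CUDA device with the flash-attn 2 kernels built for it, fp16 inputs,
    # and sequence length / head dim chosen as multiples of 8 (see the workaround notes above).
    if torch.cuda.is_available():
        batch, seqlen, nheads, headdim = 2, 128, 8, 64
        q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                        dtype=torch.float16, requires_grad=True)
        k = torch.randn_like(q, requires_grad=True)
        v = torch.randn_like(q, requires_grad=True)
        # Additive attention bias whose gradient (ds) is requested via return_ds=True.
        bias = torch.zeros(batch, nheads, seqlen, seqlen, device="cuda",
                           dtype=torch.float16, requires_grad=True)
        out = flash_attn_func(q, k, v, causal=True, attn_bias=bias, return_ds=True)
        out.sum().backward()
        print("out:", tuple(out.shape),
              "dq:", tuple(q.grad.shape),
              "ds:", None if bias.grad is None else tuple(bias.grad.shape))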