import torch

import triton
import triton.language as tl

MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2
|
def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps
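
# Illustrative values (assuming this warp heuristic): calculate_settings(768)
# returns (1024, 4) and calculate_settings(4096) returns (4096, 8), since
# BLOCK_SIZE is n rounded up to the next power of two before num_warps is picked.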
|
@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE: tl.constexpr,
    IS_EVEN_X: tl.constexpr,
):
    """
    Fast RMSNorm forward kernel. One program instance normalizes one row:
    y = x * rsqrt(mean(x^2) + eps) * w, caching rsqrt(...) in r for the
    backward pass. Adapted from the Triton layer-norm tutorial:
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    # When n_cols is a multiple of BLOCK_SIZE the mask can be skipped entirely,
    # saving a predicated load/store per element.
    if IS_EVEN_X:
        X_row = tl.load(X + col_offsets).to(tl.float32)
        W_row = tl.load(W + col_offsets)
    else:
        X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
        W_row = tl.load(W + col_offsets, mask=mask, other=0)

    # Mean of squares in float32, then the reciprocal RMS, cached for backward.
    row_var = tl.sum(X_row * X_row, axis=0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    normed = normed.to(W_row.dtype)
    output = normed * W_row

    if IS_EVEN_X:
        tl.store(Y + col_offsets, output)
    else:
        tl.store(Y + col_offsets, output, mask=mask)
|
@triton.jit
def _rms_layernorm_backward(
    dY, dY_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    dW, dW_row_stride,
    dX, dX_row_stride,
    n_cols, eps,
    BLOCK_SIZE: tl.constexpr,
    IS_EVEN_X: tl.constexpr,
):
    """
    Fast RMSNorm backward kernel. One program instance handles one row,
    reusing the inverse RMS cached in r by the forward pass. Adapted from
    the Triton layer-norm tutorial:
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY += row_idx * dY_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride
    dW += row_idx * dW_row_stride
    dX += row_idx * dX_row_stride

    if IS_EVEN_X:
        dY_row = tl.load(dY + col_offsets).to(tl.float32)
        X_row = tl.load(X + col_offsets).to(tl.float32)
        W_row = tl.load(W + col_offsets).to(tl.float32)
    else:
        dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
        X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
        W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)

    inv_var = tl.load(r).to(tl.float32)
    normed = X_row * inv_var
    # Per-row weight gradient; the host-side wrapper sums these over rows.
    dW_row = dY_row * normed

    dY_W = dY_row * W_row
    rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
    output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)

    if IS_EVEN_X:
        tl.store(dW + col_offsets, dW_row)
        tl.store(dX + col_offsets, output)
    else:
        tl.store(dW + col_offsets, dW_row, mask=mask)
        tl.store(dX + col_offsets, output, mask=mask)
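
# Gradient sketch (for reference): with normed = x * inv_var and
# inv_var = rsqrt(mean(x^2) + eps), differentiating y = w * normed gives
#   dx = inv_var * (dy * w) - (inv_var / n_cols) * normed * sum(dy * w * normed)
# which is exactly the `output` expression computed above.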
|
torch.library.define("flasht5::rmsnorm_triton_fwd", "(Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)") |
|
|
|
@torch.library.impl("flasht5::rmsnorm_triton_fwd", "default") |
|
def rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps): |
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device="cuda") |
|
r = torch.empty(n_rows, dtype=torch.float32, device="cuda") |
|
|
|
_rms_layernorm_forward[(n_rows,)]( |
|
Y, Y.stride(0), |
|
X, X.stride(0), |
|
W, W.stride(0), |
|
r, r.stride(0), |
|
n_cols, eps, |
|
BLOCK_SIZE=BLOCK_SIZE, |
|
IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0), |
|
num_warps=num_warps |
|
) |
|
|
|
return Y, r |
|
# Fake/meta implementation so the op can be traced without a real launch.
@torch.library.impl_abstract("flasht5::rmsnorm_triton_fwd")
def rmsnorm_triton_fwd_abstract(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
    Y = X.new_empty((n_rows, n_cols))
    r = X.new_empty((n_rows,), dtype=torch.float32)
    return Y, r
|
torch.library.define("flasht5::rmsnorm_triton_bwd", "(Tensor dY, Tensor r, Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)") |
|
|
|
@torch.library.impl("flasht5::rmsnorm_triton_bwd", "default") |
|
def rmsnorm_triton_bwd(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps): |
|
|
|
dX = torch.empty_like(dY) |
|
dW = torch.empty_like(dY) |
|
|
|
_rms_layernorm_backward[(n_rows,)]( |
|
dY, dY.stride(0), |
|
X, X.stride(0), |
|
W, 1, |
|
r, 1, |
|
dW, dW.stride(0), |
|
dX, dX.stride(0), |
|
n_cols, eps, |
|
BLOCK_SIZE=BLOCK_SIZE, |
|
IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0), |
|
num_warps=num_warps, |
|
) |
|
|
|
return dX, dW |
|
@torch.library.impl_abstract("flasht5::rmsnorm_triton_bwd")
def rmsnorm_triton_bwd_abstract(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
    return torch.empty_like(dY), torch.empty_like(dY)
|
class Fast_RMS_Layernorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, eps):
        shape = X.shape
        dim = shape[-1]
        X = X.view(-1, dim)
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

        Y, r = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps)

        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)

    @staticmethod
    def backward(ctx, dY):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows, n_cols = dY.shape

        # The op returns (dX, dW); dW is per-row and is reduced to W's shape below.
        dX, dW = torch.ops.flasht5.rmsnorm_triton_bwd(dY, r, X, W, ctx.eps, n_cols, n_rows, ctx.BLOCK_SIZE, ctx.num_warps)

        dX = dX.view(*shape)
        return dX, dW.sum(0), None
|
def fast_rms_layernorm(X, W, eps):
    return Fast_RMS_Layernorm.apply(X, W, eps)
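
# Minimal sanity-check sketch (an addition, not part of the original module):
# compares the Triton path against a plain-PyTorch RMSNorm reference. Assumes a
# CUDA device is available; shapes and tolerances are illustrative.
if __name__ == "__main__":
    def reference_rms_layernorm(X, W, eps):
        # Reference in float32: y = x * rsqrt(mean(x^2) + eps) * w
        X32 = X.float()
        inv_var = torch.rsqrt(X32.pow(2).mean(-1, keepdim=True) + eps)
        return (X32 * inv_var).to(X.dtype) * W

    torch.manual_seed(0)
    X = torch.randn(4, 512, dtype=torch.float16, device="cuda", requires_grad=True)
    W = torch.randn(512, dtype=torch.float16, device="cuda", requires_grad=True)

    out = fast_rms_layernorm(X, W, 1e-6)
    ref = reference_rms_layernorm(X, W, 1e-6)
    print("forward max err:", (out - ref).abs().max().item())

    out.sum().backward()
    print("dX/dW shapes:", X.grad.shape, W.grad.shape)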
|
|