|
import torch |
|
import math |
|
import triton |
|
import triton.language as tl |
|
|
|
from torch.cuda.amp import custom_bwd, custom_fwd |
|
|
|
def to_tl_dtype(input):
    """Map a torch dtype to the matching triton.language dtype.

    Only the dtypes the kernels below support are handled; anything else
    raises ``ValueError``.  The ``tl.*`` attributes are only touched on a
    successful match.
    """
    if input == torch.float32:
        return tl.float32
    if input == torch.float16:
        return tl.float16
    if input == torch.bfloat16:
        return tl.bfloat16
    if input == torch.int64:
        return tl.int64
    raise ValueError(f"Unable to convert the given input: '{input}'.")
|
|
|
|
|
|
|
# sqrt(2 / pi): the constant of the tanh approximation of GELU.
_kAlpha = math.sqrt(2.0 / math.pi)


def gelu_torch(x):
    """
    GeLU_ activation - Gaussian error linear unit (tanh approximation),
    reference torch implementation.

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    inner = _kAlpha * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1 + torch.tanh(inner))
|
|
|
def gelu_grad_torch(x):
    """
    Derivative of the tanh-approximated GELU (reference torch implementation).

    Constants: 0.79788456 ~ sqrt(2/pi) and 0.1070322243 ~ 3 * 0.044715 * sqrt(2/pi).
    """
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    sech_sq = 1 - t * t  # sech^2(u) = 1 - tanh^2(u)
    return 0.5 * x * (
        sech_sq * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + t)
|
|
|
|
|
@triton.jit
def tanh(x):
    """
    Tanh activation, device-side helper.

    Computed through the identity tanh(x) = 2 * sigmoid(2x) - 1, using
    Triton's ``tl.sigmoid`` primitive.
    """
    return 2 * tl.sigmoid(2 * x) - 1
|
|
|
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    # NB: the x >= 0 comparison is False for NaN, so NaN inputs map to 0.0.
    return tl.where(x >= 0, x, 0.0)
|
|
|
|
|
@triton.jit
def relu_grad(x):
    """
    Derivative of ReLU: 1 where x >= 0, else 0.

    The subgradient at x == 0 is taken to be 1, matching the forward's
    ``x >= 0`` predicate.
    """
    return tl.where(x >= 0, 1.0, 0.0)
|
|
|
@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit (tanh approximation),
    device-side version built on the sigmoid-based ``tanh`` helper above.

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    # _kAlpha = sqrt(2 / pi), defined at module level.
    return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))
|
|
|
|
|
@triton.jit
def gelu_grad(x):
    """
    Derivative of the tanh-approximated GELU, device-side version.

    Constants: 0.79788456 ~ sqrt(2/pi), 0.1070322243 ~ 3 * 0.044715 * sqrt(2/pi).
    """
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # d/dx [0.5 x (1 + tanh(u))] = 0.5 x sech^2(u) u' + 0.5 (1 + tanh(u))
    return 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)
|
|
|
|
|
@triton.jit
def gated_matmul_fwd(
    # Pointers to matrices
    out, input, w1, w2,
    act_input_1, act_input_2,
    # Matrix dimensions
    M, N, K,
    # Leading strides; the trailing stride of every tensor is assumed to be 1
    stride_om,
    stride_im,
    stride_wn,
    # Meta-parameters
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    USE_GELU: tl.constexpr,
    SAVE_ACTIVATION_INPUTS: tl.constexpr,
    IS_EVEN_MNK: tl.constexpr
):
    """
    Kernel for computing Out = activation(A x W + C)

    - Input has shape (M, K)
    - Weight 1 has shape (K, N)
    - Weight 2 has shape (K, N)
    - Output has shape (M, N)

    Fused gated matmul: out = act(input @ w1) * (input @ w2).  When
    SAVE_ACTIVATION_INPUTS is set, the two raw (pre-activation) GEMM results
    are also written out so the backward pass can re-use them.

    NOTE(review): the weights are addressed with strides (1, stride_wn),
    i.e. as transposed views of row-major (N, K) tensors — confirm against
    GatedMLP.forward, which passes w1.stride(0) for stride_wn.
    """
    pid = tl.program_id(0)

    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Re-order program ids in groups of GROUP_M row-blocks for better L2
    # reuse (same scheme as the standard Triton matmul tutorial).
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group along M may be narrower than GROUP_M.
    GROUP_M = min(
        num_pid_m - first_pid_m, GROUP_M
    )

    pid_m = first_pid_m + (pid % GROUP_M)
    pid_n = (pid % num_pid_in_group) // GROUP_M

    # (BLOCK_M, BLOCK_K) tile of the input, row-major.
    input_block_ptr = tl.make_block_ptr(
        base=input,
        shape=(M, K),
        strides=(stride_im, 1),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    # (BLOCK_K, BLOCK_N) tiles of the two weights, read column-major.
    w1_block_ptr = tl.make_block_ptr(
        base=w1,
        shape=(K, N),
        strides=(1, stride_wn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1),
    )

    w2_block_ptr = tl.make_block_ptr(
        base=w2,
        shape=(K, N),
        strides=(1, stride_wn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1),
    )

    # Both GEMMs accumulate in fp32 regardless of the storage dtype.
    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for i in range(0, K, BLOCK_K):
        # When M/N/K are exact multiples of the block sizes, boundary
        # masking (and its cost) can be skipped entirely.
        if IS_EVEN_MNK:
            x = tl.load(input_block_ptr)
            w1_blk = tl.load(w1_block_ptr)
            w2_blk = tl.load(w2_block_ptr)
        else:
            x = tl.load(input_block_ptr, boundary_check=(0, 1))
            w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))
            w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))

        acc1 += tl.dot(x, w1_blk)
        acc2 += tl.dot(x, w2_blk)

        # Step all tiles one BLOCK_K along the reduction dimension.
        input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_K))
        w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_K, 0))
        w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_K, 0))

    if SAVE_ACTIVATION_INPUTS:
        # Stash the pre-activation GEMM outputs for the backward pass.
        act_in_1_ptrs = tl.make_block_ptr(
            base=act_input_1,
            shape=(M, N),
            strides=(stride_om, 1),
            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )

        act_in_2_ptrs = tl.make_block_ptr(
            base=act_input_2,
            shape=(M, N),
            strides=(stride_om, 1),
            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )

        if IS_EVEN_MNK:
            tl.store(act_in_1_ptrs, acc1.to(dtype))
            tl.store(act_in_2_ptrs, acc2.to(dtype))
        else:
            tl.store(act_in_1_ptrs, acc1.to(dtype), boundary_check=(0, 1))
            tl.store(act_in_2_ptrs, acc2.to(dtype), boundary_check=(0, 1))

    # Gating: activation on the first GEMM only, then the element-wise
    # product with the second GEMM.
    if USE_GELU:
        acc1 = gelu(acc1)
    else:
        acc1 = relu(acc1)

    acc = acc1 * acc2

    out_ptrs = tl.make_block_ptr(
        base=out,
        shape=(M, N),
        strides=(stride_om, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    if IS_EVEN_MNK:
        tl.store(out_ptrs, acc.to(dtype))
    else:
        tl.store(out_ptrs, acc.to(dtype), boundary_check=(0, 1))
|
|
|
@triton.jit
def gated_matmul_bwd_ygrad(
    dout,
    y1_grad, y2_grad,
    act_input_1, act_input_2,
    M, N,
    stride_dom,
    # Meta-parameters
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_GELU: tl.constexpr,
    IS_EVEN_MNK: tl.constexpr):
    """
    Kernel for backward gated MLP

    Element-wise stage of the backward pass: re-uses the saved
    pre-activation GEMM outputs (act_input_1, act_input_2) to produce the
    gradients w.r.t. the two GEMM results.  All tensors are (M, N) and are
    assumed to share the same leading stride with a trailing stride of 1.

    Ref :
    y2_grad = torch.mul(gelu(x @ w1), dout)
    y1_grad = torch.mul(gelu_grad(x @ w1) * (x @ w2), dout)
    """
    # 2D launch grid: one program per (BLOCK_M, BLOCK_N) tile.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    actin_1_block_ptr = tl.make_block_ptr(
        base=act_input_1,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    actin_2_block_ptr = tl.make_block_ptr(
        base=act_input_2,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    dout_block_ptr = tl.make_block_ptr(
        base=dout,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    if IS_EVEN_MNK:
        dout_blk = tl.load(dout_block_ptr)
        actin_1_blk = tl.load(actin_1_block_ptr)
        actin_2_blk = tl.load(actin_2_block_ptr)
    else:
        dout_blk = tl.load(dout_block_ptr, boundary_check=(0, 1))
        actin_1_blk = tl.load(actin_1_block_ptr, boundary_check=(0, 1))
        actin_2_blk = tl.load(actin_2_block_ptr, boundary_check=(0, 1))

    # Recompute the activation and its derivative from the saved input.
    if USE_GELU:
        actin_act = gelu(actin_1_blk)
        actin_act_grad = gelu_grad(actin_1_blk)
    else:
        actin_act = relu(actin_1_blk)
        actin_act_grad = relu_grad(actin_1_blk)

    # actin_act      -> y2_grad = act(x @ w1) * dout
    # actin_act_grad -> y1_grad = act'(x @ w1) * (x @ w2) * dout
    actin_act *= dout_blk
    actin_act_grad *= actin_2_blk
    actin_act_grad *= dout_blk

    y1_grad_ptrs = tl.make_block_ptr(
        base=y1_grad,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    y2_grad_ptrs = tl.make_block_ptr(
        base=y2_grad,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    if IS_EVEN_MNK:
        tl.store(y1_grad_ptrs, actin_act_grad.to(dtype))
        tl.store(y2_grad_ptrs, actin_act.to(dtype))
    else:
        tl.store(y1_grad_ptrs, actin_act_grad.to(dtype), boundary_check=(0, 1))
        tl.store(y2_grad_ptrs, actin_act.to(dtype), boundary_check=(0, 1))
|
|
|
|
|
@triton.jit
def gated_matmul_bwd_input(
    # Pointers to matrices
    w1, w2,
    y1_grad, y2_grad,
    din,
    # Matrix dimensions
    M, N, K,
    stride_dom, stride_im,
    stride_wn,
    # Meta-parameters
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    IS_EVEN_MNK: tl.constexpr
):
    """
    Kernel for backward gated MLP
    We group along the N axis

    Computes the input gradient din (shape (M, K)) by reducing over N.

    Ref :
    x_grad = torch.matmul(y2_grad, w2.t()) + torch.matmul(y1_grad, w1.t())
    """
    pid = tl.program_id(0)

    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_k = tl.cdiv(K, BLOCK_K)

    # Grouped program ordering over the (M, K) output grid for L2 reuse.
    num_pid_in_group = GROUP_M * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group along M may be narrower than GROUP_M.
    GROUP_M = min(
        num_pid_m - first_pid_m, GROUP_M
    )

    pid_m = first_pid_m + (pid % GROUP_M)
    pid_k = (pid % num_pid_in_group) // GROUP_M

    # (BLOCK_M, BLOCK_N) tiles of the two upstream gradients, shape (M, N).
    y1_grad_block_ptr = tl.make_block_ptr(
        base=y1_grad,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    y2_grad_block_ptr = tl.make_block_ptr(
        base=y2_grad,
        shape=(M, N),
        strides=(stride_dom, 1),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )

    # Weights read row-major as (N, K) here — no transposed strides, since
    # din = y_grad @ W with W stored as (N, K).
    w1_block_ptr = tl.make_block_ptr(
        base=w1,
        shape=(N, K),
        strides=(stride_wn, 1),
        offsets=(0, pid_k * BLOCK_K),
        block_shape=(BLOCK_N, BLOCK_K),
        order=(1, 0),
    )

    w2_block_ptr = tl.make_block_ptr(
        base=w2,
        shape=(N, K),
        strides=(stride_wn, 1),
        offsets=(0, pid_k * BLOCK_K),
        block_shape=(BLOCK_N, BLOCK_K),
        order=(1, 0),
    )

    # fp32 accumulator for the (BLOCK_M, BLOCK_K) output tile.
    acc_dx = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)

    for i in range(0, N, BLOCK_N):
        # Skip boundary masking when all dims divide their block sizes.
        if IS_EVEN_MNK:
            w1_blk = tl.load(w1_block_ptr)
            w2_blk = tl.load(w2_block_ptr)
            y1_grad_blk = tl.load(y1_grad_block_ptr)
            y2_grad_blk = tl.load(y2_grad_block_ptr)
        else:
            w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))
            w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))
            y1_grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))
            y2_grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))

        # din sums the contributions of both branches.
        acc_dx += tl.dot(y2_grad_blk, w2_blk)
        acc_dx += tl.dot(y1_grad_blk, w1_blk)

        # Step one BLOCK_N along the reduction dimension N.
        w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_N, 0))
        w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_N, 0))
        y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_N))
        y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_N))

    dx_ptrs = tl.make_block_ptr(
        base=din,
        shape=(M, K),
        strides=(stride_im, 1),
        offsets=(pid_m * BLOCK_M, pid_k * BLOCK_K),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    if IS_EVEN_MNK:
        tl.store(dx_ptrs, acc_dx.to(dtype))
    else:
        tl.store(dx_ptrs, acc_dx.to(dtype), boundary_check=(0, 1))
|
|
|
|
|
@triton.jit
def gated_matmul_bwd_weights(
    # Pointers to matrices
    input,
    y1_grad, y2_grad,
    dw1, dw2,
    # Matrix dimensions
    M, N, K,
    stride_dom, stride_im,
    stride_wn,
    # Meta-parameters
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr, GROUP_N: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    IS_EVEN_MNK: tl.constexpr
):
    """
    Kernel for backward gated MLP
    We group along the M axis

    Computes both weight gradients (shape (N, K)) by reducing over M.

    Ref :
    w1_grad = torch.matmul(y1_grad.t(), x)
    w2_grad = torch.matmul(y2_grad.t(), x)
    """
    pid = tl.program_id(0)

    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_k = tl.cdiv(K, BLOCK_K)

    # Grouped program ordering over the (N, K) output grid for L2 reuse.
    num_pid_in_group = GROUP_N * num_pid_k
    group_id = pid // num_pid_in_group
    first_pid_n = group_id * GROUP_N
    # The last group along N may be narrower than GROUP_N.
    GROUP_N = min(
        num_pid_n - first_pid_n, GROUP_N
    )

    pid_n = first_pid_n + (pid % GROUP_N)
    pid_k = (pid % num_pid_in_group) // GROUP_N

    # The y grads are (M, N) in memory; strides (1, stride_dom) read them
    # transposed, as (BLOCK_N, BLOCK_M) tiles of a (N, M) view.
    y1_grad_block_ptr = tl.make_block_ptr(
        base=y1_grad,
        shape=(N, M),
        strides=(1, stride_dom),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(0, 1),
    )

    y2_grad_block_ptr = tl.make_block_ptr(
        base=y2_grad,
        shape=(N, M),
        strides=(1, stride_dom),
        offsets=(pid_n * BLOCK_N, 0),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(0, 1),
    )

    input_block_ptr = tl.make_block_ptr(
        base=input,
        shape=(M, K),
        strides=(stride_im, 1),
        offsets=(0, pid_k * BLOCK_K),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )

    # NOTE(review): `ref` is never used — this dead load looks deliberate
    # (workaround?), but its purpose isn't visible here; confirm before
    # removing.
    ref = tl.load(input + tl.arange(0, 1))

    # fp32 accumulators for the two (BLOCK_N, BLOCK_K) weight-grad tiles.
    acc_dw1 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
    acc_dw2 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)

    for i in range(0, M, BLOCK_M):
        # Skip boundary masking when all dims divide their block sizes.
        if IS_EVEN_MNK:
            y1grad_blk = tl.load(y1_grad_block_ptr)
            y2grad_blk = tl.load(y2_grad_block_ptr)
            x = tl.load(input_block_ptr)
        else:
            y1grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))
            y2grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))
            x = tl.load(input_block_ptr, boundary_check=(0, 1))

        # dW = y_grad.T @ x, accumulated over M in BLOCK_M steps.
        acc_dw1 += tl.dot(y1grad_blk, x)
        acc_dw2 += tl.dot(y2grad_blk, x)

        y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_M))
        y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_M))
        input_block_ptr = tl.advance(input_block_ptr, (BLOCK_M, 0))

    dw1_ptrs = tl.make_block_ptr(
        base=dw1,
        shape=(N, K),
        strides=(stride_wn, 1),
        offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),
        block_shape=(BLOCK_N, BLOCK_K),
        order=(1, 0),
    )

    dw2_ptrs = tl.make_block_ptr(
        base=dw2,
        shape=(N, K),
        strides=(stride_wn, 1),
        offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),
        block_shape=(BLOCK_N, BLOCK_K),
        order=(1, 0),
    )

    if IS_EVEN_MNK:
        tl.store(dw1_ptrs, acc_dw1.to(dtype))
        tl.store(dw2_ptrs, acc_dw2.to(dtype))
    else:
        tl.store(dw1_ptrs, acc_dw1.to(dtype), boundary_check=(0, 1))
        tl.store(dw2_ptrs, acc_dw2.to(dtype), boundary_check=(0, 1))
|
|
|
|
|
class GatedMLP(torch.autograd.Function):
    """
    Gated MLP computed with a fused Triton forward kernel:

        out = act(x @ w1.t()) * (x @ w2.t())

    where ``act`` is the tanh-approximated GELU (default) or ReLU.  Both
    weights are stored row-major with shape (N, K), K being the input
    feature dimension.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, w1, w2, use_gelu=True):
        """
        Args:
            x: contiguous input of shape (..., K); leading dims are
               flattened into M for the kernel.
            w1: gate-branch weight, contiguous, shape (N, K).
            w2: value-branch weight, contiguous, shape (N, K).
            use_gelu: use GELU for the gate activation, else ReLU.

        Returns:
            Tensor of shape (..., N) in the (possibly autocast) input dtype.
        """
        # Kernel tile sizes (tuned constants, independent of tensor shapes).
        BLOCK_M = 128
        BLOCK_N = 64
        BLOCK_K = 64
        GROUP_M = 8

        # Backward needs the two pre-activation GEMM outputs.
        SAVE_ACT_IN = x.requires_grad

        if torch.is_autocast_enabled():
            x = x.to(torch.get_autocast_gpu_dtype())
            w1 = w1.to(torch.get_autocast_gpu_dtype())
            w2 = w2.to(torch.get_autocast_gpu_dtype())

        assert x.is_contiguous()
        assert w1.is_contiguous()
        assert w2.is_contiguous()
        assert w1.shape == w2.shape
        # Check the feature dim on the *last* axis so plain 2D inputs work
        # too (indexing shape[2] would raise IndexError for them).
        assert x.shape[-1] == w1.shape[1]
        assert x.shape[-1] == w2.shape[1]

        # Collapse leading batch dims: the kernel operates on a 2D (M, K) view.
        x_ = x if x.ndim == 2 else x.flatten(0, -2)

        M, K = x_.shape
        N, K = w1.shape

        # When every dim divides its block size the kernel skips boundary
        # masking entirely.
        IS_EVEN_MNK = ((M % BLOCK_M) == 0) and ((N % BLOCK_N) == 0) and ((K % BLOCK_K) == 0)

        out = torch.empty((M, N), device=x.device, dtype=x.dtype)

        tl_dtype = to_tl_dtype(x.dtype)

        act_input_1, act_input_2 = None, None
        if SAVE_ACT_IN:
            act_input_1 = torch.empty_like(out)
            act_input_2 = torch.empty_like(out)

        # 1D launch; the kernel derives (pid_m, pid_n) itself for grouping.
        grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
        gated_matmul_fwd[grid](
            out,
            x_, w1, w2,
            act_input_1, act_input_2,
            M, N, K,
            out.stride(0), x_.stride(0),
            w1.stride(0),
            tl_dtype,
            BLOCK_M, GROUP_M, BLOCK_N, BLOCK_K,
            use_gelu,
            SAVE_ACT_IN,
            IS_EVEN_MNK,
        )

        ctx.save_for_backward(x_, w1, w2, act_input_1, act_input_2)
        ctx.use_gelu = use_gelu
        ctx.is_even_nmk = IS_EVEN_MNK
        ctx.x_shape = x.shape

        # Restore the caller's leading batch dims.
        out = out if x.ndim == 2 else out.reshape(*x.shape[:-1], N)

        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """
        Returns gradients w.r.t. (x, w1, w2) and None for ``use_gelu``.

        The element-wise stage runs in a Triton kernel; the remaining three
        GEMMs are delegated to torch.matmul (cuBLAS).
        """
        BLOCK_M = 64
        BLOCK_N = 64

        x_, w1, w2, act_input_1, act_input_2 = ctx.saved_tensors

        M, K = x_.shape
        N, K = w1.shape

        tl_dtype = to_tl_dtype(x_.dtype)

        # Match the forward's 2D flattening and guarantee the kernel's
        # inner-stride-1 assumption (autograd may hand us a
        # non-contiguous dout).
        dout_ = dout if dout.ndim == 2 else dout.flatten(0, -2)
        dout_ = dout_.contiguous()

        y1_grad = torch.empty_like(dout_)
        y2_grad = torch.empty_like(dout_)

        # y2_grad = act(x @ w1.t()) * dout
        # y1_grad = act'(x @ w1.t()) * (x @ w2.t()) * dout
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
        gated_matmul_bwd_ygrad[grid](
            dout_,
            y1_grad, y2_grad,
            act_input_1, act_input_2,
            M, N,
            dout_.stride(0),
            tl_dtype,
            BLOCK_M, BLOCK_N,
            ctx.use_gelu,
            ctx.is_even_nmk)

        # Weights are stored (N, K), so no transpose is needed for din.
        din = torch.matmul(y2_grad, w2) + torch.matmul(y1_grad, w1)
        dw1 = torch.matmul(y1_grad.t(), x_)
        dw2 = torch.matmul(y2_grad.t(), x_)

        # Restore the caller's original input shape.
        din = din if len(ctx.x_shape) == 2 else din.reshape(ctx.x_shape)

        return din, dw1, dw2, None
|
# Functional entry point: gated_mlp(x, w1, w2, use_gelu=True)
gated_mlp = GatedMLP.apply
|
|