|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import triton |
|
import triton.language as tl |
|
import torch |
|
|
|
# Upper bound on the (power-of-two rounded) row length that one Triton program
# may process in a single vector; calculate_settings raises beyond this.
MAX_FUSED_SIZE = 2**16  # 65536

# Triton's host-side helper: smallest power of two >= its integer argument.
next_power_of_2 = triton.next_power_of_2
|
|
|
def calculate_settings(n):
    """Choose the Triton launch configuration for rows of length ``n``.

    Returns:
        ``(BLOCK_SIZE, num_warps)`` where ``BLOCK_SIZE`` is ``n`` rounded up to
        a power of two and ``num_warps`` grows with ``BLOCK_SIZE``
        (4, 8, 16 or 32).

    Raises:
        RuntimeError: if the rounded block size exceeds ``MAX_FUSED_SIZE``.
    """
    block_size = next_power_of_2(n)
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )

    # Thresholds are checked from the largest down; the first match wins.
    warp_table = ((32768, 32), (8192, 16), (2048, 8))
    for min_block, warps in warp_table:
        if block_size >= min_block:
            return block_size, warps
    return block_size, 4
|
|
|
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
                           loss_ptr,
                           lse_ptr,
                           labels_ptr,
                           n_cols,
                           BLOCK_SIZE: tl.constexpr,
                           IS_EVEN: tl.constexpr):
    """
    Per-row cross-entropy forward: writes loss[row] and logsumexp[row].

    One program handles one row (row index = tl.program_id(0)); the whole row
    is loaded as a single BLOCK_SIZE-wide vector and reduced in fp32.

    With lse = log(sum_j exp(x_j)) and a one-hot target at column `label`:
        CE = -log(softmax(x)[label]) = lse - x[label]
    Rows whose label is -100 are ignored: their loss is 0, but their lse is
    still written.

    Args:
        logits_ptr: start of the logits matrix; consecutive rows are
            `logits_row_stride` elements apart and columns are read with unit
            stride.
        loss_ptr, lse_ptr, labels_ptr: per-row buffers indexed with unit
            stride (row r lives at offset r).
        n_cols: number of valid columns in each row.
        BLOCK_SIZE: vector width; must be >= n_cols, since columns at or past
            BLOCK_SIZE are never loaded.
        IS_EVEN: True when n_cols is a multiple of BLOCK_SIZE, so the loads
            need no mask.
    """
    row_idx = tl.program_id(0)
    # Advance every pointer to this program's row.
    logits_ptr += row_idx * logits_row_stride
    loss_ptr += row_idx
    lse_ptr += row_idx
    labels_ptr += row_idx

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # Lanes past n_cols are padding when BLOCK_SIZE > n_cols.
    mask = col_offsets < n_cols

    # Labels are narrowed to int32 for index arithmetic.
    label_idx = tl.load(labels_ptr).to(tl.int32)
    if IS_EVEN:
        logits = tl.load(logits_ptr + col_offsets).to(tl.float32)
    else:
        # Padding lanes read -inf: they never win the max and add exp(-inf) = 0
        # to the sum below.
        logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)

    max_logits = tl.max(logits, 0)
    # Numerically stable log-sum-exp: shift by the row max before exp.
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    # Saved for the backward kernel (and for an optional z-loss on lse).
    tl.store(lse_ptr, lse)

    if label_idx != -100:
        # The target logit is re-read from memory.
        # NOTE(review): label_idx is not bounds-checked here; a label outside
        # [0, n_cols) reads outside the row.
        logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
        loss = lse - logits_label
    else:
        # Ignored position (-100).
        loss = 0.0

    tl.store(loss_ptr, loss)
|
|
|
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
                            dinputs_ptr, dinputs_row_stride,
                            dloss_ptr, dloss_row_stride,
                            dzloss_ptr, dzloss_row_stride,
                            lse_ptr,
                            labels_ptr,
                            n_cols,
                            BLOCK_SIZE: tl.constexpr,
                            USE_Z_LOSS: tl.constexpr,
                            IS_EVEN: tl.constexpr):
    """
    Gradient of the per-row cross-entropy (plus optional z-loss) w.r.t. logits.

    One program handles one row (row index = tl.program_id(0)).

    With lse = log(sum_j exp(x_j)) and softmax = exp(x - lse):
        CE  = lse - x[label]   =>  dCE/dx  = softmax - onehot(label)
        d lse / dx = softmax
    so for incoming gradients dloss (of the row loss) and dzloss (of the row
    lse):
        dx = dloss * (softmax - onehot(label)) + dzloss * softmax
    Rows whose label is -100 get dloss = dzloss = 0, i.e. a zero gradient.

    Args:
        logits_ptr, dinputs_ptr: logits and output-gradient matrices; rows are
            `*_row_stride` elements apart, columns have unit stride.
        dloss_ptr, dzloss_ptr: incoming per-row gradients; a row stride of 0
            makes every row read element 0.
        lse_ptr: per-row log-sum-exp produced by the forward kernel.
        labels_ptr: per-row target column, or -100 for ignored rows.
        n_cols: number of valid columns per row (BLOCK_SIZE must be >= n_cols).
        USE_Z_LOSS: add the dzloss * softmax term.
        IS_EVEN: n_cols is a multiple of BLOCK_SIZE, so no masking is needed.
    """
    row_idx = tl.program_id(0)

    logits_ptr += row_idx * logits_row_stride
    dinputs_ptr += row_idx * dinputs_row_stride
    dloss_ptr += row_idx * dloss_row_stride
    dzloss_ptr += row_idx * dzloss_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)

    # Ignored rows (label == -100) contribute no cross-entropy gradient.
    if label_idx != -100:
        dloss = tl.load(dloss_ptr).to(tl.float32)
    else:
        dloss = 0.0

    if IS_EVEN:
        logits = tl.load(logits_ptr + col_offsets).to(tl.float32)
    else:
        # Padding lanes read -inf so their softmax is exactly 0; they are also
        # masked out of the final store.
        logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)

    # lse was reduced over the valid columns only (forward kernel), so this is
    # the row softmax. Because x - lse is bounded above by ~0, exp cannot
    # overflow.
    lse = tl.load(lse_ptr + row_idx)
    softmax = tl.exp(logits - lse)

    din = dloss * tl.where(col_offsets == label_idx, softmax - 1.0, softmax)

    if USE_Z_LOSS:
        if label_idx != -100:
            dzloss = tl.load(dzloss_ptr).to(tl.float32)
        else:
            dzloss = 0.0
        # d lse / dx = softmax. Reuse it instead of exp(x) / sum(exp(x)): that
        # form overflowed fp32 for large logits and, when padding was loaded as
        # 0, added exp(0) = 1 per padding lane to the denominator.
        din += softmax * dzloss

    if IS_EVEN:
        tl.store(dinputs_ptr + col_offsets, din)
    else:
        tl.store(dinputs_ptr + col_offsets, din, mask=mask)
|
|
|
|
|
|
|
torch.library.define(
    "flasht5::cross_entropy_triton_fwd",
    "(Tensor logits, Tensor labels, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)",
)


@torch.library.impl("flasht5::cross_entropy_triton_fwd", "default")
def cross_entropy_triton_fwd(logits, labels, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Launch ``_cross_entropy_forward`` with one Triton program per row.

    Returns:
        ``(losses, logsumexp)``: freshly allocated float32 tensors of length
        ``n_rows`` on ``logits.device``.
    """
    def _new_row_vector():
        return torch.empty(n_rows, dtype=torch.float32, device=logits.device)

    losses = _new_row_vector()
    logsumexp = _new_row_vector()

    # Rows that exactly fill the block can skip masking inside the kernel.
    even_rows = (n_cols % BLOCK_SIZE) == 0
    launch_grid = (n_rows,)
    _cross_entropy_forward[launch_grid](
        logits,
        logits.stride(0),
        losses,
        logsumexp,
        labels,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_EVEN=even_rows,
        num_warps=num_warps,
    )
    return losses, logsumexp
|
|
|
|
|
# Plain decorator form: register_fake's optional second positional argument is
# itself the fake implementation, so the real Triton wrapper must not be
# passed there.
@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
def cross_entropy_triton_fwd_abstract(logits, labels, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Fake (shape/dtype-only) implementation of ``flasht5::cross_entropy_triton_fwd``.

    Mirrors the real op's allocations without launching the kernel: two
    float32 vectors of length ``n_rows`` on ``logits.device`` (per-row loss and
    per-row log-sum-exp).
    """
    losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    return losses, logsumexp
|
|
|
torch.library.define(
    "flasht5::cross_entropy_triton_bwd",
    "(Tensor dlosses, Tensor dlogsumexp, Tensor logits, Tensor logsumexp, Tensor labels, float z_loss_factor, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> Tensor",
)


@torch.library.impl("flasht5::cross_entropy_triton_bwd", "default")
def cross_entropy_triton_bwd(dlosses, dlogsumexp, logits, logsumexp, labels, z_loss_factor, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Launch ``_cross_entropy_backward`` and return the gradient w.r.t. ``logits``.

    The row strides of ``dlosses`` and ``dlogsumexp`` are forwarded to the
    kernel, presumably so that expanded (stride-0) incoming gradients are read
    correctly. The z-loss term is compiled in only when ``z_loss_factor`` is
    nonzero.
    """
    grad_logits = torch.empty_like(logits)

    use_z_loss = z_loss_factor != 0.0
    even_rows = (n_cols % BLOCK_SIZE) == 0

    _cross_entropy_backward[(n_rows,)](
        logits, logits.stride(0),
        grad_logits, grad_logits.stride(0),
        dlosses, dlosses.stride(0),
        dlogsumexp, dlogsumexp.stride(0),
        logsumexp,
        labels,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        USE_Z_LOSS=use_z_loss,
        IS_EVEN=even_rows,
        num_warps=num_warps,
    )
    return grad_logits
|
|
|
|
|
# Plain decorator form: register_fake's optional second positional argument is
# itself the fake implementation, so the real Triton wrapper must not be
# passed there.
@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
def cross_entropy_triton_bwd_abstract(dlosses, dlogsumexp, logits, logsumexp, labels, z_loss_factor, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Fake (shape/dtype-only) implementation of ``flasht5::cross_entropy_triton_bwd``.

    The gradient has the same shape, dtype and device as ``logits``, matching
    the real op's ``torch.empty_like(logits)`` allocation.
    """
    return torch.empty_like(logits)
|
|
|
class Fast_CrossEntropyLoss(torch.autograd.Function):
    """Autograd bridge to the ``flasht5`` Triton cross-entropy custom ops.

    ``forward(logits, labels, z_loss_factor)`` takes a 2-D ``logits`` of shape
    ``(n_rows, n_cols)`` and returns ``(losses, logsumexp)``, one float32 value
    per row. Rows labelled -100 get a loss of 0. ``backward`` maps the
    gradients of both outputs back onto ``logits``; ``labels`` and
    ``z_loss_factor`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, logits, labels, z_loss_factor):
        n_rows, n_cols = logits.shape
        block, warps = calculate_settings(n_cols)

        losses, logsumexp = torch.ops.flasht5.cross_entropy_triton_fwd(
            logits,
            labels,
            n_cols,
            n_rows,
            BLOCK_SIZE=block,
            num_warps=warps,
        )

        # Plain Python launch settings ride on ctx; tensors go through
        # save_for_backward.
        ctx.BLOCK_SIZE = block
        ctx.num_warps = warps
        ctx.z_loss_factor = z_loss_factor
        ctx.save_for_backward(logits, logsumexp, labels)
        return losses, logsumexp

    @staticmethod
    def backward(ctx, dlosses, dlogsumexp):
        logits, logsumexp, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape

        grad_logits = torch.ops.flasht5.cross_entropy_triton_bwd(
            dlosses, dlogsumexp,
            logits, logsumexp, labels,
            ctx.z_loss_factor,
            n_cols, n_rows,
            ctx.BLOCK_SIZE, ctx.num_warps,
        )
        # One entry per forward input: logits, labels, z_loss_factor.
        return grad_logits, None, None
|
|
|
def fast_cross_entropy_loss(logits, labels, z_loss_factor=0.0):
    """
    Mean token cross-entropy and an optional z-loss, via fused Triton kernels.

    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,); positions equal to -100 are ignored.
        z_loss_factor: weight of the z-loss term ``z_loss_factor * lse**2``,
            where ``lse`` is each position's log-sum-exp over the vocabulary.

    Returns:
        ``(ce_loss, z_loss)``: two scalar tensors. Each is summed over the
        non-ignored positions and divided by their count. If every position is
        ignored, the count is 0 and both values are NaN (0 / 0).
    """
    batch, seq_len, d = logits.shape
    assert (labels.shape == (batch, seq_len))
    assert (d <= MAX_FUSED_SIZE)

    # reshape rather than view, so shifted/sliced inputs such as
    # logits[:, :-1] and labels[:, 1:] (not viewable as 2-D / 1-D) are copied
    # instead of raising. The kernels index labels with unit stride, hence
    # the explicit contiguous().
    flat_labels = labels.reshape(-1).contiguous()
    loss, lse = Fast_CrossEntropyLoss.apply(
        logits.reshape(batch * seq_len, d),
        flat_labels,
        z_loss_factor,
    )

    valid = flat_labels != -100
    n_items = torch.count_nonzero(valid)

    # Ignored positions already contribute 0 to the CE sum, and the backward
    # kernel drops their lse gradient. Mask them out of the z-loss too, so that
    # the value, its normaliser (n_items) and its gradient all agree.
    masked_lse = lse.masked_fill(~valid, 0.0)
    z_loss = z_loss_factor * torch.square(masked_lse).sum()

    return loss.sum() / n_items, z_loss / n_items
|
|