from typing import Any, Callable, Dict, Hashable, Tuple

import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction

try:
    import triton.language.math as tlmath  # Triton 2.1
except ImportError:
    import triton.language.libdevice as tlmath  # Triton 2.0


class TritonKernel:
    """Thin wrapper that caches compiled Triton kernels per (device, constexpr kwargs)."""

    def __init__(
        self,
        kernel_fn: JITFunction,
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}

    def run(self, *args, **kwargs):
        # Switch to the device of the first input if it is not the current one
        input_device = args[0].device
        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
        if input_device.index != cur_dev_idx:
            prev_dev_idx = cur_dev_idx
            torch.cuda.set_device(input_device.index)

        # Compute grid
        grid = self.grid_fn_(args)

        # Reuse the cached kernel if one was already compiled for this
        # (device, constexpr kwargs) combination
        kernel_key = (input_device,) + tuple(kwargs.items())
        if kernel_key in self.kernel_cache_:
            kernel = self.kernel_cache_[kernel_key]
            kernel[grid](*args)
        else:
            # Compile and store new kernel
            kernel = self.kernel_fn_[grid](*args, **kwargs)
            self.kernel_cache_[kernel_key] = kernel

        # Restore previous device (set_device(-1) is a no-op, i.e. nothing was switched)
        torch.cuda.set_device(prev_dev_idx)


@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head); grid = (batch, seq_len, num_heads)
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)

    # X is read with a per-token stride of num_heads * 3 * HEAD_DIM
    # (e.g. a head slice of a fused QKV buffer); Y is written contiguously.
    block_idx = tl.arange(0, HEAD_DIM)
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)

    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)

    # rotate_half: pair each element with the one half a head away and
    # negate the partner for the first half of the head dimension
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)

    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    y = x * cos + x_rot * sin
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))


apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
    return y


@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per token row; first pass accumulates the mean of squares in fp32
    tok_idx = tl.program_id(0)
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        mean_sq += x * x / hidden_dim
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)

    # Second pass: normalize and scale by the weight vector
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)


rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y
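

# --- Usage sketch (illustration only, not part of the original module) ---
# The shapes, the rotary-table construction, and the fused-QKV layout below are
# assumptions chosen to match the indexing in _apply_rope_fwd_kernel (per-token
# stride of num_heads * 3 * HEAD_DIM for X) and _rms_norm_fwd_kernel (row-major
# (tokens, hidden_dim) input). Requires a CUDA device.
if __name__ == "__main__":
    batch, seq_len, num_heads, head_dim, hidden_dim = 2, 16, 8, 64, 512
    device = torch.device("cuda")

    # Assumed fused QKV projection output: (batch, seq, 3, num_heads, head_dim)
    qkv = torch.randn(
        batch, seq_len, 3, num_heads, head_dim, dtype=torch.float16, device=device
    )

    # Assumed rotary tables of shape (seq_len, head_dim), matching
    # freq_idx = tok_idx * HEAD_DIM + block_idx in the kernel
    pos = torch.arange(seq_len, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (
        10000.0
        ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
    )
    freqs = torch.cat([torch.outer(pos, inv_freq)] * 2, dim=-1)
    cos, sin = freqs.cos(), freqs.sin()

    # The query slice is a non-contiguous view into qkv; the kernel's "* 3"
    # stride accounts for that, and the result is a contiguous tensor of
    # shape (batch, seq_len, num_heads, head_dim)
    q = qkv[:, :, 0]
    q_rope = apply_rotary_emb(q, cos, sin)

    # RMSNorm over the last dimension of a (batch, seq, hidden_dim) activation
    x = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float16, device=device)
    weight = torch.ones(hidden_dim, dtype=torch.float16, device=device)
    x_norm = rms_norm(x, weight, eps=1e-6)

    print(q_rope.shape, x_norm.shape)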