Shangming Cai committed on
Commit
af64202
1 Parent(s): f2191b9

Add ApplyRoPE and RMSNorm kernels written in OpenAI Triton.

Browse files
Files changed (4) hide show
  1. config.json +1 -0
  2. configuration_qwen.py +2 -0
  3. modeling_qwen.py +33 -3
  4. triton_kernels.py +115 -0
config.json CHANGED
@@ -44,6 +44,7 @@
44
  "use_cache": true,
45
  "use_dynamic_ntk": true,
46
  "use_flash_attn": "auto",
 
47
  "use_logn_attn": true,
48
  "vocab_size": 151936
49
  }
 
44
  "use_cache": true,
45
  "use_dynamic_ntk": true,
46
  "use_flash_attn": "auto",
47
+ "use_triton": "auto",
48
  "use_logn_attn": true,
49
  "vocab_size": 151936
50
  }
configuration_qwen.py CHANGED
@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
32
  use_dynamic_ntk=True,
33
  use_logn_attn=True,
34
  use_flash_attn="auto",
 
35
  intermediate_size=22016,
36
  no_bias=True,
37
  tie_word_embeddings=False,
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
61
  self.use_dynamic_ntk = use_dynamic_ntk
62
  self.use_logn_attn = use_logn_attn
63
  self.use_flash_attn = use_flash_attn
 
64
  self.no_bias = no_bias
65
  self.use_cache_quantization = use_cache_quantization
66
  self.use_cache_kernel = use_cache_kernel
 
32
  use_dynamic_ntk=True,
33
  use_logn_attn=True,
34
  use_flash_attn="auto",
35
+ use_triton="auto",
36
  intermediate_size=22016,
37
  no_bias=True,
38
  tie_word_embeddings=False,
 
62
  self.use_dynamic_ntk = use_dynamic_ntk
63
  self.use_logn_attn = use_logn_attn
64
  self.use_flash_attn = use_flash_attn
65
+ self.use_triton = use_triton
66
  self.no_bias = no_bias
67
  self.use_cache_quantization = use_cache_quantization
68
  self.use_cache_kernel = use_cache_kernel
modeling_qwen.py CHANGED
@@ -35,7 +35,7 @@ except ImportError:
35
  from torch import nn
36
 
37
  SUPPORT_CUDA = torch.cuda.is_available()
38
- SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
39
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
40
  SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
41
 
@@ -76,7 +76,9 @@ We detect you have activated flash attention support, but running model computat
76
  """
77
 
78
  apply_rotary_emb_func = None
 
79
  rms_norm = None
 
80
  flash_attn_unpadded_func = None
81
  flash_attn_func = None
82
 
@@ -120,6 +122,24 @@ def _import_flash_attn():
120
  "https://github.com/Dao-AILab/flash-attention"
121
  )
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def quantize_cache_v(fdata, bits, qmax, qmin):
124
  # b, s, head, h-dim->b, head, s, h-dim
125
  qtype = torch.uint8
@@ -978,6 +998,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
978
  if config.use_flash_attn:
979
  _import_flash_attn()
980
 
 
 
 
 
 
 
981
  self.transformer = QWenModel(config)
982
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
983
 
@@ -1335,7 +1361,9 @@ def apply_rotary_pos_emb(t, freqs):
1335
  rot_dim = freqs[0].shape[-1]
1336
  cos, sin = freqs
1337
  t_float = t.float()
1338
- if apply_rotary_emb_func is not None and t.is_cuda:
 
 
1339
  # apply_rotary_emb in flash_attn requires cos/sin to be of
1340
  # shape (seqlen, rotary_dim / 2) and apply rotary embedding
1341
  # to the first rotary_dim of the input
@@ -1358,7 +1386,9 @@ class RMSNorm(torch.nn.Module):
1358
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1359
 
1360
  def forward(self, x):
1361
- if rms_norm is not None and x.is_cuda:
 
 
1362
  return rms_norm(x, self.weight, self.eps)
1363
  else:
1364
  output = self._norm(x.float()).type_as(x)
 
35
  from torch import nn
36
 
37
  SUPPORT_CUDA = torch.cuda.is_available()
38
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
39
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
40
  SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
41
 
 
76
  """
77
 
78
  apply_rotary_emb_func = None
79
+ apply_rotary_emb_func_triton = None
80
  rms_norm = None
81
+ rms_norm_triton = None
82
  flash_attn_unpadded_func = None
83
  flash_attn_func = None
84
 
 
122
  "https://github.com/Dao-AILab/flash-attention"
123
  )
124
 
125
def _import_triton():
    """Try to load the Triton RoPE and RMSNorm kernels into module globals.

    On success, ``apply_rotary_emb_func_triton`` and ``rms_norm_triton`` are
    bound to ``apply_rotary_emb`` and ``rms_norm`` from ``.triton_kernels``.
    If that import raises ``ImportError``, a warning is logged and neither
    global is modified. A notice is also logged for each kernel whose
    flash_attn counterpart has already been loaded.
    """
    global apply_rotary_emb_func_triton, rms_norm_triton

    # Keep the try body limited to the import so that only a genuine import
    # failure is reported as "Failed to import Triton kernels".
    try:
        from .triton_kernels import apply_rotary_emb as _triton_apply_rotary_emb
        from .triton_kernels import rms_norm as _triton_rms_norm
    except ImportError:
        # ``Logger.warn`` is a deprecated alias; ``Logger.warning`` is the
        # supported spelling.
        logger.warning("Warning: Failed to import Triton kernels.")
        return

    if apply_rotary_emb_func is not None:
        logger.warning(
            "Using Triton rotary kernel instead of flash_attn for inference."
        )
    apply_rotary_emb_func_triton = _triton_apply_rotary_emb

    if rms_norm is not None:
        logger.warning(
            "Using Triton rms_norm kernel instead of flash_attn for inference."
        )
    rms_norm_triton = _triton_rms_norm
142
+
143
  def quantize_cache_v(fdata, bits, qmax, qmin):
144
  # b, s, head, h-dim->b, head, s, h-dim
145
  qtype = torch.uint8
 
998
  if config.use_flash_attn:
999
  _import_flash_attn()
1000
 
1001
+ if config.use_triton == "auto":
1002
+ logger.warn("Try importing Triton kernels for faster inference...")
1003
+ config.use_triton = SUPPORT_TORCH2
1004
+ if config.use_triton:
1005
+ _import_triton()
1006
+
1007
  self.transformer = QWenModel(config)
1008
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1009
 
 
1361
  rot_dim = freqs[0].shape[-1]
1362
  cos, sin = freqs
1363
  t_float = t.float()
1364
+ if apply_rotary_emb_func_triton is not None and t.is_cuda and (not t.requires_grad):
1365
+ return apply_rotary_emb_func_triton(t, cos, sin)
1366
+ elif apply_rotary_emb_func is not None and t.is_cuda:
1367
  # apply_rotary_emb in flash_attn requires cos/sin to be of
1368
  # shape (seqlen, rotary_dim / 2) and apply rotary embedding
1369
  # to the first rotary_dim of the input
 
1386
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1387
 
1388
  def forward(self, x):
1389
+ if rms_norm_triton is not None and x.is_cuda and (not x.requires_grad):
1390
+ return rms_norm_triton(x, self.weight, self.eps)
1391
+ elif rms_norm is not None and x.is_cuda:
1392
  return rms_norm(x, self.weight, self.eps)
1393
  else:
1394
  output = self._norm(x.float()).type_as(x)
triton_kernels.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, Hashable, Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from triton.compiler import CompiledKernel
7
+ from triton.runtime import JITFunction
8
+
9
+ try:
10
+ import triton.language.math as tlmath # Triton 2.1
11
+ except ImportError:
12
+ import triton.language.libdevice as tlmath # Triton 2.0
13
+
14
+
15
class TritonKernel:
    """Launch helper for a ``@triton.jit`` function with a compiled-kernel cache.

    ``kernel_fn`` is the JIT function to launch. ``grid_fn`` receives the
    tuple of positional launch arguments and returns the launch grid.
    """

    def __init__(
        self,
        kernel_fn: JITFunction,
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        # Maps a launch signature (see ``run``) to the kernel compiled for it.
        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}

    @staticmethod
    def _arg_signature(arg: Any) -> Hashable:
        """Return the part of ``arg`` that must match for a cached kernel to be reused."""
        if isinstance(arg, torch.Tensor):
            # A compiled kernel is tied to the pointer element types it was
            # built for. NOTE(review): Triton's JIT also appears to specialize
            # on 16-byte pointer alignment -- confirm for the installed version.
            return (arg.dtype, arg.data_ptr() % 16 == 0)
        if isinstance(arg, int) and not isinstance(arg, bool):
            # NOTE(review): Triton's JIT appears to specialize integer
            # arguments on "divisible by 16" and "equal to 1" -- confirm for
            # the installed version. Keying on them can only cause an extra
            # compile, never a wrong reuse.
            return (int, arg % 16 == 0, arg == 1)
        return type(arg)

    def run(self, *args, **kwargs):
        """Launch the kernel on the device of ``args[0]``.

        Positional ``args`` are passed to the kernel as runtime arguments;
        ``kwargs`` are passed only on the compiling launch and are part of
        the cache key. The previously current CUDA device is restored on
        exit, including when the launch raises.
        """
        input_device = args[0].device

        # The context manager switches to the input's device only when needed
        # and always restores the previous device afterwards.
        with torch.cuda.device(input_device.index):
            grid = self.grid_fn_(args)

            kernel_key = (
                input_device,
                tuple(self._arg_signature(arg) for arg in args),
            ) + tuple(kwargs.items())

            kernel = self.kernel_cache_.get(kernel_key)
            if kernel is not None:
                # Reuse the kernel compiled for an identical launch signature.
                kernel[grid](*args)
            else:
                # Compile (and launch) through the JIT function, then keep the
                # resulting kernel for later calls with the same signature.
                self.kernel_cache_[kernel_key] = self.kernel_fn_[grid](
                    *args, **kwargs
                )
48
+
49
+
50
@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head) grid cell; each program handles the
    # HEAD_DIM elements of a single head.
    batch_idx = tl.program_id(0)
    tok_idx = tl.program_id(1)
    head_idx = tl.program_id(2)
    seq_len = tl.num_programs(1)
    num_heads = tl.num_programs(2)

    lane = tl.arange(0, HEAD_DIM)
    row = batch_idx * seq_len + tok_idx

    # Input offsets carry a factor of 3 on the head count, while output
    # offsets do not.
    # NOTE(review): presumably X is one of the Q/K/V views of a packed QKV
    # buffer -- confirm against the caller.
    x_base = (row * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base + lane)

    # Fetch the element half a head away (wrapping around) and negate it for
    # lanes in the first half.
    partner = tl.load(X + x_base + (HEAD_DIM // 2 + lane) % HEAD_DIM)
    rotated = tl.where(lane < HEAD_DIM // 2, -partner, partner)

    # Cos/Sin are read as HEAD_DIM values per token, indexed by token only.
    freq_off = tok_idx * HEAD_DIM + lane
    cos = tl.load(Cos + freq_off)
    sin = tl.load(Sin + freq_off)

    y_off = (row * num_heads + head_idx) * HEAD_DIM + lane
    y = x * cos + rotated * sin
    tl.store(Y + y_off, y.to(Y.dtype.element_ty))
68
+
69
+
70
# Launch grid: the first three dimensions of the first launch argument.
apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel,
    lambda launch_args: tuple(launch_args[0].shape[:3]),
)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embedding to ``x`` with the Triton kernel.

    Allocates a new tensor with the shape, dtype and device of ``x``, launches
    ``apply_rope_fwd_kernel`` with ``HEAD_DIM`` set to the last dimension of
    ``x``, and returns the new tensor. ``x`` is not modified.

    NOTE(review): the kernel computes input offsets itself rather than from
    the strides of ``x``; presumably ``x`` must have the layout produced by
    the model's QKV split -- confirm against the caller.
    """
    head_dim = x.size(-1)
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, out, HEAD_DIM=head_dim)
    return out
79
+
80
+
81
@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per row of ``hidden_dim`` elements; rows are addressed as
    # ``row * hidden_dim + column``.
    row = tl.program_id(0)
    row_start = row * hidden_dim

    # Pass 1: accumulate x^2 / hidden_dim per lane in float32, then reduce to
    # the reciprocal root-mean-square of the row.
    acc = tl.zeros([BLOCK_SIZE], tl.float32)
    for start in range(0, hidden_dim, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < hidden_dim
        x = tl.load(X + row_start + cols, mask=in_bounds, other=0).to(tl.float32)
        acc += x * x / hidden_dim
    rrms = tlmath.rsqrt(tl.sum(acc, 0) + eps)

    # Pass 2: scale each element by rrms and the per-column weight, then store
    # in the output's element type. Unlike pass 1, x is not explicitly cast
    # to float32 before the multiply.
    for start in range(0, hidden_dim, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < hidden_dim
        offs = row_start + cols
        x = tl.load(X + offs, mask=in_bounds, other=0)
        w = tl.load(W + cols, mask=in_bounds, other=0)
        scaled = x * rrms * w
        tl.store(Y + offs, scaled.to(Y.dtype.element_ty), mask=in_bounds)
102
+
103
+
104
# Launch grid: one program per row, i.e. the product of all dimensions of the
# first launch argument except the last.
rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    """Apply RMS normalization over the last dimension of ``x`` with the Triton kernel.

    Args:
        x: Input tensor; normalization is applied along its last dimension.
        weight: Per-element scale, indexed by position in the last dimension.
        eps: Value added to the mean of squares before the reciprocal square root.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    # The kernel addresses X and Y as flat row-major buffers
    # (``row * hidden_dim + column``) and W as a flat vector, so the inputs
    # must be contiguous. ``contiguous()`` returns the same tensor when that
    # already holds, so the common path makes no copy.
    x = x.contiguous()
    weight = weight.contiguous()

    # With a contiguous ``x`` the output is contiguous as well.
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y