Shangming Cai
commited on
Commit
•
af64202
1
Parent(s):
f2191b9
Add ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
Browse files- config.json +1 -0
- configuration_qwen.py +2 -0
- modeling_qwen.py +33 -3
- triton_kernels.py +115 -0
config.json
CHANGED
@@ -44,6 +44,7 @@
|
|
44 |
"use_cache": true,
|
45 |
"use_dynamic_ntk": true,
|
46 |
"use_flash_attn": "auto",
|
|
|
47 |
"use_logn_attn": true,
|
48 |
"vocab_size": 151936
|
49 |
}
|
|
|
44 |
"use_cache": true,
|
45 |
"use_dynamic_ntk": true,
|
46 |
"use_flash_attn": "auto",
|
47 |
+
"use_triton": "auto",
|
48 |
"use_logn_attn": true,
|
49 |
"vocab_size": 151936
|
50 |
}
|
configuration_qwen.py
CHANGED
@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
|
|
32 |
use_dynamic_ntk=True,
|
33 |
use_logn_attn=True,
|
34 |
use_flash_attn="auto",
|
|
|
35 |
intermediate_size=22016,
|
36 |
no_bias=True,
|
37 |
tie_word_embeddings=False,
|
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
|
|
61 |
self.use_dynamic_ntk = use_dynamic_ntk
|
62 |
self.use_logn_attn = use_logn_attn
|
63 |
self.use_flash_attn = use_flash_attn
|
|
|
64 |
self.no_bias = no_bias
|
65 |
self.use_cache_quantization = use_cache_quantization
|
66 |
self.use_cache_kernel = use_cache_kernel
|
|
|
32 |
use_dynamic_ntk=True,
|
33 |
use_logn_attn=True,
|
34 |
use_flash_attn="auto",
|
35 |
+
use_triton="auto",
|
36 |
intermediate_size=22016,
|
37 |
no_bias=True,
|
38 |
tie_word_embeddings=False,
|
|
|
62 |
self.use_dynamic_ntk = use_dynamic_ntk
|
63 |
self.use_logn_attn = use_logn_attn
|
64 |
self.use_flash_attn = use_flash_attn
|
65 |
+
self.use_triton = use_triton
|
66 |
self.no_bias = no_bias
|
67 |
self.use_cache_quantization = use_cache_quantization
|
68 |
self.use_cache_kernel = use_cache_kernel
|
modeling_qwen.py
CHANGED
@@ -35,7 +35,7 @@ except ImportError:
|
|
35 |
from torch import nn
|
36 |
|
37 |
SUPPORT_CUDA = torch.cuda.is_available()
|
38 |
-
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.
|
39 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
40 |
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
41 |
|
@@ -76,7 +76,9 @@ We detect you have activated flash attention support, but running model computat
|
|
76 |
"""
|
77 |
|
78 |
apply_rotary_emb_func = None
|
|
|
79 |
rms_norm = None
|
|
|
80 |
flash_attn_unpadded_func = None
|
81 |
flash_attn_func = None
|
82 |
|
@@ -120,6 +122,24 @@ def _import_flash_attn():
|
|
120 |
"https://github.com/Dao-AILab/flash-attention"
|
121 |
)
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
def quantize_cache_v(fdata, bits, qmax, qmin):
|
124 |
# b, s, head, h-dim->b, head, s, h-dim
|
125 |
qtype = torch.uint8
|
@@ -978,6 +998,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
978 |
if config.use_flash_attn:
|
979 |
_import_flash_attn()
|
980 |
|
|
|
|
|
|
|
|
|
|
|
|
|
981 |
self.transformer = QWenModel(config)
|
982 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
983 |
|
@@ -1335,7 +1361,9 @@ def apply_rotary_pos_emb(t, freqs):
|
|
1335 |
rot_dim = freqs[0].shape[-1]
|
1336 |
cos, sin = freqs
|
1337 |
t_float = t.float()
|
1338 |
-
if
|
|
|
|
|
1339 |
# apply_rotary_emb in flash_attn requires cos/sin to be of
|
1340 |
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
|
1341 |
# to the first rotary_dim of the input
|
@@ -1358,7 +1386,9 @@ class RMSNorm(torch.nn.Module):
|
|
1358 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1359 |
|
1360 |
def forward(self, x):
|
1361 |
-
if
|
|
|
|
|
1362 |
return rms_norm(x, self.weight, self.eps)
|
1363 |
else:
|
1364 |
output = self._norm(x.float()).type_as(x)
|
|
|
35 |
from torch import nn
|
36 |
|
37 |
SUPPORT_CUDA = torch.cuda.is_available()
|
38 |
+
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
|
39 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
40 |
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
41 |
|
|
|
76 |
"""
|
77 |
|
78 |
apply_rotary_emb_func = None
|
79 |
+
apply_rotary_emb_func_triton = None
|
80 |
rms_norm = None
|
81 |
+
rms_norm_triton = None
|
82 |
flash_attn_unpadded_func = None
|
83 |
flash_attn_func = None
|
84 |
|
|
|
122 |
"https://github.com/Dao-AILab/flash-attention"
|
123 |
)
|
124 |
|
125 |
+
def _import_triton():
|
126 |
+
global apply_rotary_emb_func_triton, rms_norm_triton
|
127 |
+
try:
|
128 |
+
from .triton_kernels import apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
|
129 |
+
if apply_rotary_emb_func is not None:
|
130 |
+
logger.warn(
|
131 |
+
"Using Triton rotary kernel instead of flash_attn for inference."
|
132 |
+
)
|
133 |
+
apply_rotary_emb_func_triton = __apply_rotary_emb
|
134 |
+
if rms_norm is not None:
|
135 |
+
logger.warn(
|
136 |
+
"Using Triton rms_norm kernel instead of flash_attn for inference."
|
137 |
+
)
|
138 |
+
rms_norm_triton = __rms_norm
|
139 |
+
except ImportError:
|
140 |
+
logger.warn("Warning: Failed to import Triton kernels.")
|
141 |
+
return
|
142 |
+
|
143 |
def quantize_cache_v(fdata, bits, qmax, qmin):
|
144 |
# b, s, head, h-dim->b, head, s, h-dim
|
145 |
qtype = torch.uint8
|
|
|
998 |
if config.use_flash_attn:
|
999 |
_import_flash_attn()
|
1000 |
|
1001 |
+
if config.use_triton == "auto":
|
1002 |
+
logger.warn("Try importing Triton kernels for faster inference...")
|
1003 |
+
config.use_triton = SUPPORT_TORCH2
|
1004 |
+
if config.use_triton:
|
1005 |
+
_import_triton()
|
1006 |
+
|
1007 |
self.transformer = QWenModel(config)
|
1008 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1009 |
|
|
|
1361 |
rot_dim = freqs[0].shape[-1]
|
1362 |
cos, sin = freqs
|
1363 |
t_float = t.float()
|
1364 |
+
if apply_rotary_emb_func_triton is not None and t.is_cuda and (not t.requires_grad):
|
1365 |
+
return apply_rotary_emb_func_triton(t, cos, sin)
|
1366 |
+
elif apply_rotary_emb_func is not None and t.is_cuda:
|
1367 |
# apply_rotary_emb in flash_attn requires cos/sin to be of
|
1368 |
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
|
1369 |
# to the first rotary_dim of the input
|
|
|
1386 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1387 |
|
1388 |
def forward(self, x):
|
1389 |
+
if rms_norm_triton is not None and x.is_cuda and (not x.requires_grad):
|
1390 |
+
return rms_norm_triton(x, self.weight, self.eps)
|
1391 |
+
elif rms_norm is not None and x.is_cuda:
|
1392 |
return rms_norm(x, self.weight, self.eps)
|
1393 |
else:
|
1394 |
output = self._norm(x.float()).type_as(x)
|
triton_kernels.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict, Hashable, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
from triton.compiler import CompiledKernel
|
7 |
+
from triton.runtime import JITFunction
|
8 |
+
|
9 |
+
try:
|
10 |
+
import triton.language.math as tlmath # Triton 2.1
|
11 |
+
except ImportError:
|
12 |
+
import triton.language.libdevice as tlmath # Triton 2.0
|
13 |
+
|
14 |
+
|
15 |
+
class TritonKernel:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
kernel_fn: JITFunction,
|
19 |
+
grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
|
20 |
+
) -> None:
|
21 |
+
self.kernel_fn_ = kernel_fn
|
22 |
+
self.grid_fn_ = grid_fn
|
23 |
+
self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
|
24 |
+
|
25 |
+
def run(self, *args, **kwargs):
|
26 |
+
# Set current device
|
27 |
+
input_device = args[0].device
|
28 |
+
prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
|
29 |
+
if input_device.index != cur_dev_idx:
|
30 |
+
prev_dev_idx = cur_dev_idx
|
31 |
+
torch.cuda.set_device(input_device.index)
|
32 |
+
|
33 |
+
# Compute grid
|
34 |
+
grid = self.grid_fn_(args)
|
35 |
+
|
36 |
+
# Use cached kernel if possible
|
37 |
+
kernel_key = (input_device,) + tuple(kwargs.items())
|
38 |
+
if kernel_key in self.kernel_cache_:
|
39 |
+
kernel = self.kernel_cache_[kernel_key]
|
40 |
+
kernel[grid](*args)
|
41 |
+
else:
|
42 |
+
# Compile and store new kernel
|
43 |
+
kernel = self.kernel_fn_[grid](*args, **kwargs)
|
44 |
+
self.kernel_cache_[kernel_key] = kernel
|
45 |
+
|
46 |
+
# Restore previous device
|
47 |
+
torch.cuda.set_device(prev_dev_idx)
|
48 |
+
|
49 |
+
|
50 |
+
@triton.jit
|
51 |
+
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
|
52 |
+
batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
53 |
+
seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
|
54 |
+
block_idx = tl.arange(0, HEAD_DIM)
|
55 |
+
x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
|
56 |
+
x = tl.load(X + x_base_idx + block_idx)
|
57 |
+
freq_idx = tok_idx * HEAD_DIM + block_idx
|
58 |
+
cos = tl.load(Cos + freq_idx)
|
59 |
+
rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
|
60 |
+
x_rot = tl.load(X + x_base_idx + rot_idx)
|
61 |
+
x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
|
62 |
+
sin = tl.load(Sin + freq_idx)
|
63 |
+
y_idx = (
|
64 |
+
(batch_idx * seq_len + tok_idx) * num_heads + head_idx
|
65 |
+
) * HEAD_DIM + block_idx
|
66 |
+
y = x * cos + x_rot * sin
|
67 |
+
tl.store(Y + y_idx, y.to(Y.dtype.element_ty))
|
68 |
+
|
69 |
+
|
70 |
+
apply_rope_fwd_kernel = TritonKernel(
|
71 |
+
_apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
76 |
+
y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
|
77 |
+
apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
|
78 |
+
return y
|
79 |
+
|
80 |
+
|
81 |
+
@triton.jit
|
82 |
+
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
|
83 |
+
tok_idx = tl.program_id(0)
|
84 |
+
|
85 |
+
mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
|
86 |
+
for offset in range(0, hidden_dim, BLOCK_SIZE):
|
87 |
+
dim_idx = offset + tl.arange(0, BLOCK_SIZE)
|
88 |
+
x = tl.load(
|
89 |
+
X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
|
90 |
+
).to(tl.float32)
|
91 |
+
mean_sq += x * x / hidden_dim
|
92 |
+
rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
|
93 |
+
|
94 |
+
for offset in range(0, hidden_dim, BLOCK_SIZE):
|
95 |
+
dim_idx = offset + tl.arange(0, BLOCK_SIZE)
|
96 |
+
dim_mask = dim_idx < hidden_dim
|
97 |
+
hidden_idx = tok_idx * hidden_dim + dim_idx
|
98 |
+
x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
|
99 |
+
w = tl.load(W + dim_idx, mask=dim_mask, other=0)
|
100 |
+
y = x * rrms * w
|
101 |
+
tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
|
102 |
+
|
103 |
+
|
104 |
+
rms_norm_fwd_kernel = TritonKernel(
|
105 |
+
_rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
|
110 |
+
y = torch.empty_like(x)
|
111 |
+
hidden_dim = x.size(-1)
|
112 |
+
rms_norm_fwd_kernel.run(
|
113 |
+
x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
|
114 |
+
)
|
115 |
+
return y
|