jupyterjazz committed on
Commit
7ad815b
1 Parent(s): 4d09ca8

no-flash-attention-during-inference (#22)


- feat: no flash attention during inference (e3423c0243fe029474cbf283046f02e11e82202f)

Files changed (1)
  1. rotary.py +40 -21
rotary.py CHANGED
@@ -6,11 +6,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 from einops import rearrange, repeat
-try:
-    from flash_attn.ops.triton.rotary import apply_rotary
-except ImportError:
-    def apply_rotary(*args, **kwargs):
-        raise RuntimeError('RoPE requires flash-attention to be installed')
+
+if torch.cuda.is_available():
+    try:
+        from flash_attn.ops.triton.rotary import apply_rotary
+    except ImportError:
+        def apply_rotary(*args, **kwargs):
+            raise RuntimeError('RoPE requires flash-attention to be installed')
 
 
 def rotate_half(x, interleaved=False):
@@ -60,6 +62,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
             interleaved=interleaved,
             inplace=inplace,
         )
+
         if isinstance(seqlen_offsets, int):
             ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
             ctx.seqlen_offsets = seqlen_offsets
@@ -82,6 +85,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
         if not ctx.interleaved and not ctx.inplace:
             do = do.clone()
+
         dx = apply_rotary(
             do,
             cos,
@@ -150,21 +154,37 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         # batch, seqlen, three, nheads, headdim = qkv.shape
         assert qkv.shape[-3] == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
-            # Call 1 kernel instead of 2 kernels
-            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
-            # dimensions, we get the same tensor
-            qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
-            # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
-            apply_rotary(
-                qk,
-                cos,
-                sin,
-                seqlen_offsets=seqlen_offsets,
-                interleaved=interleaved,
-                inplace=True,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
+
+            if torch.cuda.is_available():
+                # Call 1 kernel instead of 2 kernels
+                # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+                # dimensions, we get the same tensor
+                qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
+                # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+                apply_rotary(
+                    qk,
+                    cos,
+                    sin,
+                    seqlen_offsets=seqlen_offsets,
+                    interleaved=interleaved,
+                    inplace=True,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                q_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 0],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                k_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 1],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
@@ -228,7 +248,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             sin_k = sin if sin_k is None else sin_k
             dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
             apply_rotary(
-
                 dq,
                 cos,
                 sin,
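
The new CPU branch calls apply_rotary_emb_torch, which is not part of this diff. As a rough sketch of what such a pure-PyTorch fallback typically looks like (modeled on the reference implementation that ships with flash-attention; the exact helper defined elsewhere in rotary.py may differ), reusing the rotate_half helper already present in the file:

# Sketch only: assumes apply_rotary_emb_torch lives in rotary.py and follows
# the flash-attention reference implementation.
import torch
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    # Mirrors the rotate_half helper in rotary.py: pairs feature dimensions
    # and rotates them by 90 degrees.
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim / 2).
    # Only the first rotary_dim features are rotated; the rest pass through.
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )

Note the design difference between the two branches of the forward pass: the Triton kernel rotates q and k in place on a single reshaped view of qkv, while the torch fallback computes q_rot and k_rot out of place and restacks them with the untouched value projection.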