Files changed (3) hide show
  1. mha.py +1 -0
  2. modeling_xlm_roberta.py +1 -3
  3. rotary.py +22 -11
mha.py CHANGED
@@ -463,6 +463,7 @@ class MHA(nn.Module):
463
  scale_base=rotary_emb_scale_base,
464
  interleaved=rotary_emb_interleaved,
465
  device=device,
 
466
  )
467
 
468
  if fused_bias_fc and FusedDense is None:
 
463
  scale_base=rotary_emb_scale_base,
464
  interleaved=rotary_emb_interleaved,
465
  device=device,
466
+ use_flash_attn=use_flash_attn,
467
  )
468
 
469
  if fused_bias_fc and FusedDense is None:
modeling_xlm_roberta.py CHANGED
@@ -63,9 +63,7 @@ logger = logging.getLogger(__name__)
63
 
64
 
65
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
66
- if not getattr(config, "use_flash_attn", False):
67
- return False
68
- if not torch.cuda.is_available():
69
  return False
70
  if importlib.util.find_spec("flash_attn") is None:
71
  logger.warning(
 
63
 
64
 
65
  def get_use_flash_attn(config: XLMRobertaFlashConfig):
66
+ if not getattr(config, "use_flash_attn", False) or not torch.cuda.is_available():
 
 
67
  return False
68
  if importlib.util.find_spec("flash_attn") is None:
69
  logger.warning(
rotary.py CHANGED
@@ -4,20 +4,11 @@
4
 
5
  # Copyright (c) 2023, Tri Dao.
6
 
7
- import math
8
  from typing import Optional, Tuple, Union
9
 
10
  import torch
11
  from einops import rearrange, repeat
12
 
13
- if torch.cuda.is_available():
14
- try:
15
- from flash_attn.ops.triton.rotary import apply_rotary
16
- except ImportError:
17
-
18
- def apply_rotary(*args, **kwargs):
19
- raise RuntimeError("RoPE requires flash-attention to be installed")
20
-
21
 
22
  def rotate_half(x, interleaved=False):
23
  if not interleaved:
@@ -69,6 +60,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
69
  cu_seqlens: Optional[torch.Tensor] = None,
70
  max_seqlen: Optional[int] = None,
71
  ):
 
 
72
  out = apply_rotary(
73
  x,
74
  cos,
@@ -95,6 +88,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
95
 
96
  @staticmethod
97
  def backward(ctx, do):
 
 
98
  seqlen_offsets = ctx.seqlen_offsets
99
  if seqlen_offsets is None:
100
  cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
@@ -169,12 +164,15 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
169
  seqlen_offsets: Union[int, torch.Tensor] = 0,
170
  cu_seqlens: Optional[torch.Tensor] = None,
171
  max_seqlen: Optional[int] = None,
 
172
  ):
173
  # batch, seqlen, three, nheads, headdim = qkv.shape
174
  assert qkv.shape[-3] == 3
175
  if cos_k is None and sin_k is None and qkv.is_contiguous():
176
 
177
- if torch.cuda.is_available():
 
 
178
  # Call 1 kernel instead of 2 kernels
179
  # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
180
  # dimensions, we get the same tensor
@@ -205,6 +203,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
205
  )
206
  qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
207
  else:
 
 
208
  cos_k = cos if cos_k is None else cos_k
209
  sin_k = sin if sin_k is None else sin_k
210
  q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
@@ -241,6 +241,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
241
 
242
  @staticmethod
243
  def backward(ctx, dqkv):
 
 
244
  seqlen_offsets = ctx.seqlen_offsets
245
  if seqlen_offsets is None:
246
  cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
@@ -301,6 +303,7 @@ def apply_rotary_emb_qkv_(
301
  seqlen_offsets: Union[int, torch.Tensor] = 0,
302
  cu_seqlens: Optional[torch.Tensor] = None,
303
  max_seqlen: Optional[int] = None,
 
304
  ):
305
  """
306
  Arguments:
@@ -321,7 +324,7 @@ def apply_rotary_emb_qkv_(
321
  Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
322
  """
323
  return ApplyRotaryEmbQKV_.apply(
324
- qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
325
  )
326
 
327
 
@@ -337,6 +340,8 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
337
  cu_seqlens: Optional[torch.Tensor] = None,
338
  max_seqlen: Optional[int] = None,
339
  ):
 
 
340
  # batch, seqlen, two, nheads, headdim = kv.shape
341
  assert kv.shape[-3] == 2
342
  k = kv[..., 0, :, :]
@@ -364,6 +369,8 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
364
 
365
  @staticmethod
366
  def backward(ctx, dkv):
 
 
367
  seqlen_offsets = ctx.seqlen_offsets
368
  if seqlen_offsets is None:
369
  cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
@@ -443,6 +450,7 @@ class RotaryEmbedding(torch.nn.Module):
443
  scale_base=None,
444
  pos_idx_in_fp32=True,
445
  device=None,
 
446
  ):
447
  """
448
  interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
@@ -462,6 +470,7 @@ class RotaryEmbedding(torch.nn.Module):
462
  self.dim = dim
463
  self._base = float(base)
464
  self.pos_idx_in_fp32 = pos_idx_in_fp32
 
465
  # Generate and save the inverse frequency buffer (non trainable)
466
  inv_freq = self._compute_inv_freq(device)
467
  self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -588,6 +597,7 @@ class RotaryEmbedding(torch.nn.Module):
588
  seqlen_offsets=seqlen_offset,
589
  cu_seqlens=cu_seqlens,
590
  max_seqlen=max_seqlen,
 
591
  )
592
  else:
593
  return apply_rotary_emb_qkv_(
@@ -600,6 +610,7 @@ class RotaryEmbedding(torch.nn.Module):
600
  seqlen_offsets=seqlen_offset,
601
  cu_seqlens=cu_seqlens,
602
  max_seqlen=max_seqlen,
 
603
  )
604
  else:
605
  q = qkv
 
4
 
5
  # Copyright (c) 2023, Tri Dao.
6
 
 
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
10
  from einops import rearrange, repeat
11
 
 
 
 
 
 
 
 
 
12
 
13
  def rotate_half(x, interleaved=False):
14
  if not interleaved:
 
60
  cu_seqlens: Optional[torch.Tensor] = None,
61
  max_seqlen: Optional[int] = None,
62
  ):
63
+ from flash_attn.ops.triton.rotary import apply_rotary
64
+
65
  out = apply_rotary(
66
  x,
67
  cos,
 
88
 
89
  @staticmethod
90
  def backward(ctx, do):
91
+ from flash_attn.ops.triton.rotary import apply_rotary
92
+
93
  seqlen_offsets = ctx.seqlen_offsets
94
  if seqlen_offsets is None:
95
  cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
 
164
  seqlen_offsets: Union[int, torch.Tensor] = 0,
165
  cu_seqlens: Optional[torch.Tensor] = None,
166
  max_seqlen: Optional[int] = None,
167
+ use_flash_attn: bool = True,
168
  ):
169
  # batch, seqlen, three, nheads, headdim = qkv.shape
170
  assert qkv.shape[-3] == 3
171
  if cos_k is None and sin_k is None and qkv.is_contiguous():
172
 
173
+ if use_flash_attn:
174
+ from flash_attn.ops.triton.rotary import apply_rotary
175
+
176
  # Call 1 kernel instead of 2 kernels
177
  # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
178
  # dimensions, we get the same tensor
 
203
  )
204
  qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
205
  else:
206
+ from flash_attn.ops.triton.rotary import apply_rotary
207
+
208
  cos_k = cos if cos_k is None else cos_k
209
  sin_k = sin if sin_k is None else sin_k
210
  q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
 
241
 
242
  @staticmethod
243
  def backward(ctx, dqkv):
244
+ from flash_attn.ops.triton.rotary import apply_rotary
245
+
246
  seqlen_offsets = ctx.seqlen_offsets
247
  if seqlen_offsets is None:
248
  cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
 
303
  seqlen_offsets: Union[int, torch.Tensor] = 0,
304
  cu_seqlens: Optional[torch.Tensor] = None,
305
  max_seqlen: Optional[int] = None,
306
+ use_flash_attn=True,
307
  ):
308
  """
309
  Arguments:
 
324
  Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
325
  """
326
  return ApplyRotaryEmbQKV_.apply(
327
+ qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
328
  )
329
 
330
 
 
340
  cu_seqlens: Optional[torch.Tensor] = None,
341
  max_seqlen: Optional[int] = None,
342
  ):
343
+ from flash_attn.ops.triton.rotary import apply_rotary
344
+
345
  # batch, seqlen, two, nheads, headdim = kv.shape
346
  assert kv.shape[-3] == 2
347
  k = kv[..., 0, :, :]
 
369
 
370
  @staticmethod
371
  def backward(ctx, dkv):
372
+ from flash_attn.ops.triton.rotary import apply_rotary
373
+
374
  seqlen_offsets = ctx.seqlen_offsets
375
  if seqlen_offsets is None:
376
  cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
 
450
  scale_base=None,
451
  pos_idx_in_fp32=True,
452
  device=None,
453
+ use_flash_attn=True,
454
  ):
455
  """
456
  interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
 
470
  self.dim = dim
471
  self._base = float(base)
472
  self.pos_idx_in_fp32 = pos_idx_in_fp32
473
+ self.use_flash_attn = use_flash_attn
474
  # Generate and save the inverse frequency buffer (non trainable)
475
  inv_freq = self._compute_inv_freq(device)
476
  self.register_buffer("inv_freq", inv_freq, persistent=False)
 
597
  seqlen_offsets=seqlen_offset,
598
  cu_seqlens=cu_seqlens,
599
  max_seqlen=max_seqlen,
600
+ use_flash_attn=self.use_flash_attn,
601
  )
602
  else:
603
  return apply_rotary_emb_qkv_(
 
610
  seqlen_offsets=seqlen_offset,
611
  cu_seqlens=cu_seqlens,
612
  max_seqlen=max_seqlen,
613
+ use_flash_attn=self.use_flash_attn,
614
  )
615
  else:
616
  q = qkv