jupyterjazz committed on
Commit f2e0e62
1 Parent(s): ab85772

feat: support rope


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (3)
  1. mha.py +332 -44
  2. modeling_xlm_roberta.py +2 -2
  3. rotary.py +570 -0
mha.py CHANGED
@@ -1,6 +1,3 @@
-# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
-# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
-
 # Copyright (c) 2023, Tri Dao.
 
 import math
@@ -10,6 +7,8 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat
 
+from flash_attn.utils.distributed import get_dim_for_local_rank
+
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
@@ -28,10 +27,7 @@ try:
 except ImportError:
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
 
-try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-except ImportError:
-    RotaryEmbedding = None
+from .rotary import RotaryEmbedding
 
 
 # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -62,15 +58,7 @@ class FlashSelfAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(
-        self,
-        causal=False,
-        softmax_scale=None,
-        attention_dropout=0.0,
-        window_size=(-1, -1),
-        alibi_slopes=None,
-        deterministic=False,
-    ):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -78,7 +66,6 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
-        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -102,8 +89,6 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
-        if self.alibi_slopes is not None:
-            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -116,7 +101,6 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
         else:
@@ -126,7 +110,6 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
@@ -142,15 +125,7 @@ class FlashCrossAttention(nn.Module):
             (default: 0.0)
     """
 
-    def __init__(
-        self,
-        causal=False,
-        softmax_scale=None,
-        attention_dropout=0.0,
-        alibi_slopes=None,
-        window_size=(-1, -1),
-        deterministic=False,
-    ):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -158,7 +133,6 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
-        self.window_size = window_size
         self.deterministic = deterministic
 
     def forward(
@@ -188,8 +162,6 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
-        if self.alibi_slopes is not None:
-            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -209,7 +181,6 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
                 deterministic=self.deterministic,
            )
         else:
@@ -223,7 +194,6 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.window_size,
                 deterministic=self.deterministic,
             )
 
@@ -399,7 +369,6 @@ class MHA(nn.Module):
         rotary_emb_scale_base=None,
         rotary_emb_interleaved=False,
         use_alibi=False,
-        window_size=(-1, -1),
         fused_bias_fc=False,
         use_flash_attn=False,
         return_residual=False,
@@ -429,8 +398,6 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
-        if window_size != (-1, -1):
-            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
@@ -461,12 +428,12 @@ class MHA(nn.Module):
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
             if use_flash_attn
             else CrossAttention
         )
@@ -619,7 +586,7 @@ class MHA(nn.Module):
             assert key_padding_mask is None
             assert self.use_flash_attn
             assert not self.dwconv
-            assert self.rotary_emb_dim == 0
+            # assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
@@ -643,7 +610,9 @@ class MHA(nn.Module):
                 else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else max_seqlen
+        )
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -664,7 +633,10 @@ class MHA(nn.Module):
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
-                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                        qkv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
                     )
                 if inference_params is None:
                     if not self.checkpointing:
@@ -715,7 +687,11 @@ class MHA(nn.Module):
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
-                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                        q,
+                        kv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
                    )
                 if inference_params is None:
                     if not self.checkpointing:
@@ -731,3 +707,315 @@ class MHA(nn.Module):
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
 
+
+class ParallelMHA(nn.Module):
+    """Multi-head self-attention and cross-attention"""
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        process_group,
+        num_heads_kv=None,
+        qkv_proj_bias=True,
+        out_proj_bias=True,
+        dropout=0.0,
+        softmax_scale=None,
+        causal=False,
+        layer_idx=None,
+        rotary_emb_dim=0,
+        rotary_emb_base=10000.0,
+        rotary_emb_scale_base=None,
+        rotary_emb_interleaved=False,
+        use_alibi=False,
+        use_flash_attn=False,
+        checkpointing=False,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.layer_idx = layer_idx
+        self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
+        self.checkpointing = checkpointing
+        self.process_group = process_group
+        self.world_size = process_group.size()
+        self.local_rank = torch.distributed.get_rank(process_group)
+
+        self.num_heads = num_heads
+        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+
+        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
+        assert (
+            self.num_heads % self.num_heads_kv == 0
+        ), "num_heads must be divisible by num_heads_kv"
+
+        self.num_heads_per_rank = get_dim_for_local_rank(
+            self.num_heads, self.world_size, self.local_rank
+        )
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(
+            self.num_heads_kv, self.world_size, self.local_rank
+        )
+        self.head_dim = self.embed_dim // num_heads
+        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
+
+        if use_alibi:
+            assert use_flash_attn, "ALiBi code path requires flash_attn"
+            num_heads_local = math.ceil(self.num_heads / self.world_size)
+            alibi_slopes = torch.tensor(
+                get_alibi_slopes(num_heads)[
+                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
+                ],
+                device=device,
+            )
+        else:
+            alibi_slopes = None
+
+        if self.rotary_emb_dim > 0:
+            assert RotaryEmbedding is not None, "rotary_emb is not installed"
+            self.rotary_emb = RotaryEmbedding(
+                self.rotary_emb_dim,
+                base=rotary_emb_base,
+                scale_base=rotary_emb_scale_base,
+                interleaved=rotary_emb_interleaved,
+                device=device,
+            )
+
+        if ColumnParallelLinear is None or RowParallelLinear is None:
+            raise ImportError("fused_dense is not installed")
+        self.Wqkv = ColumnParallelLinear(
+            embed_dim,
+            qkv_dim,
+            process_group,
+            bias=qkv_proj_bias,
+            sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
+            **factory_kwargs,
+        )
+        inner_attn_cls = (
+            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else SelfAttention
+        )
+        inner_cross_attn_cls = (
+            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
+            if use_flash_attn
+            else CrossAttention
+        )
+        self.inner_attn = inner_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.inner_cross_attn = inner_cross_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            process_group,
+            bias=out_proj_bias,
+            sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim,
+            **factory_kwargs,
+        )
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv_per_rank,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
+
+    def _update_kv_cache(self, kv, inference_params):
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+        return _update_kv_cache(kv, inference_params, self.layer_idx)
+
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
+        """
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
+        """
+        assert inference_params is not None and inference_params.seqlen_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_seqlen, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.seqlen_offset
+        )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
+        )
+        return context
+
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
+            # TODO: this only uses seqlen_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.seqlen_offset
+            )
+            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
+            context = flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+                alibi_slopes=alibi_slopes,
+            )
+            return context
+
+    def forward(
+        self, x, seqlen=None, inference_params=None, cu_seqlens=None, max_seqlen=None, **kwargs
+    ):
+        """
+        Arguments:
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None and cu_seqlens=None.
+                (seqlen, hidden_dim) if cu_seqlens not None, seqlen equal cu_seqlens[-1].
+                If seqlen is not None and cu_seqlens=None, x is (batch * seqlen, hidden_dim). This is so that when we
+                split x during sequence parallel, we split the batch * seqlen dimension
+                (in case batch is small).
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into x. Only applicable when using
+                FlashAttention.
+            max_seqlen: int. Maximum sequence length in the batch.
+        """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+            assert seqlen is None
+            assert self.use_flash_attn
+        if inference_params is not None:
+            assert cu_seqlens is None and max_seqlen is None
+        qkv = self.Wqkv(x)
+        if seqlen is not None:
+            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
+        kwargs = (
+            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
+            if self.use_flash_attn
+            else kwargs
+        )
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.seqlen_offset
+            )
+        )
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else max_seqlen
+        )
+        if self.num_heads_kv == self.num_heads:
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            if (
+                inference_params is None
+                or inference_params.seqlen_offset == 0
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
+            ):
+                if self.rotary_emb_dim > 0:
+                    qkv = self.rotary_emb(
+                        qkv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
+                    )
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_attn(qkv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                else:
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
+            else:
+                context = self._apply_rotary_update_kvcache_attention(
+                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                )
+        else:
+            q = rearrange(
+                qkv[..., : self.num_heads_per_rank * self.head_dim],
+                "... (h d) -> ... h d",
+                d=self.head_dim,
+            )
+            kv = rearrange(
+                qkv[..., self.num_heads_per_rank * self.head_dim :],
+                "... (two hkv d) -> ... two hkv d",
+                two=2,
+                d=self.head_dim,
+            )
+            if (
+                inference_params is None
+                or inference_params.seqlen_offset == 0
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
+            ):
+                if self.rotary_emb_dim > 0:
+                    q, kv = self.rotary_emb(
+                        q,
+                        kv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
+                    )
+                if inference_params is None:
+                    if not self.checkpointing:
+                        context = self.inner_cross_attn(q, kv, **kwargs)
+                    else:
+                        context = torch.utils.checkpoint.checkpoint(
+                            self.inner_cross_attn, q, kv, **kwargs
+                        )
+                else:
+                    context = self._update_kvcache_attention(q, kv, inference_params)
+            else:
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
+        context = rearrange(context, "... h d -> ... (h d)")
+        if seqlen is not None:
+            context = rearrange(context, "b s d -> (b s) d")
+        out = self.out_proj(context)
+        return out
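A minimal usage sketch of the rotary-enabled attention path (not part of the commit). It assumes a CUDA device, a working flash-attn install, that this mha.py module is importable locally, and that MHA keeps flash-attn's constructor signature; the sizes (embed_dim=768, num_heads=12, rotary_emb_dim=64) are illustrative only.

```python
import torch
from mha import MHA  # hypothetical local import of the module shown above

# rotary_emb_dim > 0 switches on the rotary code path added in this commit
mha = MHA(
    embed_dim=768,
    num_heads=12,
    rotary_emb_dim=64,
    use_flash_attn=True,
    device="cuda",
    dtype=torch.float16,
)
x = torch.randn(2, 128, 768, device="cuda", dtype=torch.float16)
out = mha(x)          # rotary embeddings are applied to qkv before flash attention
print(out.shape)      # expected: torch.Size([2, 128, 768])
```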
modeling_xlm_roberta.py CHANGED
@@ -45,7 +45,7 @@ from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
 from .stochastic_depth import StochasticDepth
-
+from .rotary import RotaryEmbedding
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -91,7 +91,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
         rotary_kwargs["rotary_emb_dim"] = getattr(
-            config, "rotary_emb_dim", config.hidden_size
+            config, "rotary_emb_dim", config.hidden_size / 12
         )
         rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
         rotary_kwargs["rotary_emb_scale_base"] = getattr(
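A small illustration of the default computed above (not part of the commit). With a BERT-base-like config (hidden_size=768, 12 attention heads, both hypothetical values here), the fallback `config.hidden_size / 12` is presumably meant to recover the per-head dimension; note that `/` yields a float in Python 3, so an explicit integer `rotary_emb_dim` (or `//`) may be what downstream code expects.

```python
class DummyConfig:
    # hypothetical BERT-base-like settings, for illustration only
    position_embedding_type = "rotary"
    hidden_size = 768

config = DummyConfig()
rotary_emb_dim = getattr(config, "rotary_emb_dim", config.hidden_size / 12)
print(rotary_emb_dim)  # 64.0 (a float) unless rotary_emb_dim is set explicitly on the config
```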
rotary.py ADDED
@@ -0,0 +1,570 @@
+# Copyright (c) 2023, Tri Dao.
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from einops import rearrange, repeat
+from flash_attn.ops.triton.rotary import apply_rotary
+
+
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
+
+
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    return torch.cat(
+        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
+        dim=-1,
+    )
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        cos,
+        sin,
+        interleaved=False,
+        inplace=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        out = apply_rotary(
+            x,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            interleaved=interleaved,
+            inplace=inplace,
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        ctx.max_seqlen = max_seqlen
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cu_seqlens = ctx.saved_tensors
+        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+        if not ctx.interleaved and not ctx.inplace:
+            do = do.clone()
+        dx = apply_rotary(
+            do,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+            interleaved=ctx.interleaved,
+            inplace=ctx.inplace,
+            conjugate=True,
+        )
+        return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        inplace: if True, apply rotary embedding in-place.
+        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding to the first rotary_dim of x.
+    """
+    return ApplyRotaryEmb.apply(
+        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+# For backward compatibility
+apply_rotary_emb_func = apply_rotary_emb
+
+
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        qkv,
+        cos,
+        sin,
+        cos_k=None,
+        sin_k=None,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        # batch, seqlen, three, nheads, headdim = qkv.shape
+        assert qkv.shape[-3] == 3
+        if cos_k is None and sin_k is None and qkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
+            # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            apply_rotary(
+                qk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
+            apply_rotary(
+                q,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            apply_rotary(
+                k,
+                cos_k,
+                sin_k,
+                seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.max_seqlen = max_seqlen
+        ctx.interleaved = interleaved
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
+            apply_rotary(
+                dqk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
+            apply_rotary(
+
+                dq,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+            apply_rotary(
+                dk,
+                cos_k,
+                sin_k,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+        return dqkv, None, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb_qkv_(
+    qkv,
+    cos,
+    sin,
+    cos_k=None,
+    sin_k=None,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 3, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 3, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
+    """
+    return ApplyRotaryEmbQKV_.apply(
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+class ApplyRotaryEmbKV_(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        kv,
+        cos,
+        sin,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        # batch, seqlen, two, nheads, headdim = kv.shape
+        assert kv.shape[-3] == 2
+        k = kv[..., 0, :, :]
+        apply_rotary(
+            k,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=interleaved,
+            inplace=True,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.max_seqlen = max_seqlen
+        ctx.interleaved = interleaved
+        return kv
+
+    @staticmethod
+    def backward(ctx, dkv):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cu_seqlens = ctx.saved_tensors
+        apply_rotary(
+            dkv[..., 0, :, :],
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=ctx.interleaved,
+            inplace=True,
+            conjugate=True,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+        )
+        return dkv, None, None, None, None, None, None
+
+
+apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
+
+
+def apply_rotary_emb_kv_(
+    kv,
+    cos,
+    sin,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 2, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 2, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of K.
+    """
+    return ApplyRotaryEmbKV_.apply(
+        kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et. al).
+    A crucial insight from the method is that the query and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_, GPT-NeoX was an inspiration
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
+    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        base=10000.0,
+        interleaved=False,
+        scale_base=None,
+        pos_idx_in_fp32=True,
+        device=None,
+    ):
+        """
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+            otherwise they might be in lower precision.
+            This option was added because previously (before 2023-07-02), when we construct
+            the position indices, we use the dtype of self.inv_freq. In most cases this would
+            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+            self.inv_freq would be bf16, and the position indices are also in bf16.
+            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+            embeddings for some positions will coincide.
+            To maintain compatibility with models previously trained in pure bf16,
+            we add this option.
+        """
+        super().__init__()
+        self.dim = dim
+        self.base = float(base)
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.interleaved = interleaved
+        self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)
+
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        # Reset the tables if the sequence length has changed,
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, inv_freq)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+    def forward(
+        self,
+        qkv: torch.Tensor,
+        kv: Optional[torch.Tensor] = None,
+        seqlen_offset: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
+            else it's just q of shape (batch, seqlen, nheads, headdim)
+        kv: (batch, seqlen, 2, nheads, headdim)
+        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
+            should pass in max_seqlen, which will update the cos / sin cache up to that length.
+        Apply rotary embedding *inplace* to qkv and / or kv.
+        """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+        seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
+        if max_seqlen is not None:
+            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if kv is None:
+            if self.scale is None:
+                return apply_rotary_emb_qkv_(
+                    qkv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                return apply_rotary_emb_qkv_(
+                    qkv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+        else:
+            q = qkv
+            q = apply_rotary_emb_func(
+                q,
+                self._cos_cached,
+                self._sin_cached,
+                interleaved=self.interleaved,
+                inplace=True,
+                seqlen_offsets=seqlen_offset,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            if self.scale is None:
+                kv = apply_rotary_emb_kv_(
+                    kv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                kv = apply_rotary_emb_kv_(
+                    kv,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            return q, kv
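A small shape check for the pure-PyTorch reference path above (apply_rotary_emb_torch); it runs on CPU and does not need the Triton kernel. The cos/sin tables are built the same way RotaryEmbedding._update_cos_sin_cache does; the import path and the sizes are illustrative assumptions.

```python
import torch
from rotary import apply_rotary_emb_torch  # hypothetical local import of the file above

batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 32
x = torch.randn(batch, seqlen, nheads, headdim)

# Build cos/sin as RotaryEmbedding does (base=10000, non-interleaved).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)            # (seqlen, rotary_dim / 2)
cos, sin = torch.cos(freqs), torch.sin(freqs)

out = apply_rotary_emb_torch(x, cos, sin)   # rotates the first rotary_dim of each head
assert out.shape == x.shape
assert torch.allclose(out[..., rotary_dim:], x[..., rotary_dim:])  # tail dims are untouched
```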