Files changed (3)
  1. mha.py +14 -12
  2. modeling_xlm_roberta.py +2 -2
  3. rotary.py +575 -0
mha.py CHANGED
@@ -1,7 +1,5 @@
-# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
-# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
-
 # Copyright (c) 2023, Tri Dao.
+# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
 
 import math
 from functools import partial
@@ -28,10 +26,7 @@ try:
 except ImportError:
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
 
-try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-except ImportError:
-    RotaryEmbedding = None
+from .rotary import RotaryEmbedding
 
 
 # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
@@ -619,7 +614,6 @@ class MHA(nn.Module):
             assert key_padding_mask is None
             assert self.use_flash_attn
             assert not self.dwconv
-            assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
@@ -643,7 +637,9 @@ class MHA(nn.Module):
                 else inference_params.seqlen_offset
             )
         )
-        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else max_seqlen
+        )
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -664,7 +660,10 @@ class MHA(nn.Module):
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
-                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                        qkv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
                     )
                 if inference_params is None:
                     if not self.checkpointing:
@@ -715,7 +714,11 @@ class MHA(nn.Module):
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
-                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                        q,
+                        kv,
+                        seqlen_offset=seqlen_offset,
+                        cu_seqlens=cu_seqlens,
+                        max_seqlen=rotary_max_seqlen,
                     )
                 if inference_params is None:
                     if not self.checkpointing:
@@ -730,4 +733,3 @@ class MHA(nn.Module):
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
-
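For orientation, the two new keyword arguments follow the FlashAttention varlen convention: when the padded batch is unpacked into a single token dimension, `cu_seqlens` holds the cumulative sequence boundaries and `max_seqlen` the length of the longest sequence, while `seqlen_offset` keeps its KV-cache meaning. A minimal sketch of how such inputs are typically constructed (the tensor names below are illustrative, not part of the diff):

```python
import torch

# Three sequences of lengths 5, 3 and 7 packed along one token dimension,
# as the varlen FlashAttention path expects.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())  # 7
# cu_seqlens -> tensor([ 0,  5,  8, 15], dtype=torch.int32)

nheads, headdim = 12, 64
qkv = torch.randn(int(seqlens.sum()), 3, nheads, headdim)  # (total_tokens, 3, nheads, headdim)

# Conceptually, the updated MHA.forward now forwards these through to the rotary module:
# qkv = self.rotary_emb(qkv, seqlen_offset=0, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
```
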
modeling_xlm_roberta.py CHANGED
@@ -45,7 +45,7 @@ from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
 from .stochastic_depth import StochasticDepth
-
+from .rotary import RotaryEmbedding
 
 try:
     from flash_attn.ops.fused_dense import FusedDense
@@ -91,7 +91,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
         rotary_kwargs["rotary_emb_dim"] = getattr(
-            config, "rotary_emb_dim", config.hidden_size
+            config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
        )
         rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
         rotary_kwargs["rotary_emb_scale_base"] = getattr(
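The new default ties the rotary dimension to the per-head width rather than the whole hidden size, which is what `RotaryEmbedding` expects as its `dim`. A quick check with illustrative base-model numbers (these values are examples, not read from any config in this repo); note that `/` is true division, so the default comes out as a float unless `rotary_emb_dim` is set explicitly:

```python
# Illustrative XLM-RoBERTa-base-like numbers (not taken from the diff).
hidden_size = 768
num_attention_heads = 12

rotary_emb_dim = hidden_size / num_attention_heads  # 64.0, i.e. the per-head dimension
assert int(rotary_emb_dim) == hidden_size // num_attention_heads == 64
```
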
rotary.py ADDED
@@ -0,0 +1,575 @@
+# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
+# Copyright (c) 2023, Tri Dao.
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from einops import rearrange, repeat
+try:
+    from flash_attn.ops.triton.rotary import apply_rotary
+except ImportError:
+    def apply_rotary(*args, **kwargs):
+        raise RuntimeError('RoPE requires flash-attention to be installed')
+
+
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
+
+
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    return torch.cat(
+        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
+        dim=-1,
+    )
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        cos,
+        sin,
+        interleaved=False,
+        inplace=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        out = apply_rotary(
+            x,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            interleaved=interleaved,
+            inplace=inplace,
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.interleaved = interleaved
+        ctx.inplace = inplace
+        ctx.max_seqlen = max_seqlen
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cu_seqlens = ctx.saved_tensors
+        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+        if not ctx.interleaved and not ctx.inplace:
+            do = do.clone()
+        dx = apply_rotary(
+            do,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+            interleaved=ctx.interleaved,
+            inplace=ctx.inplace,
+            conjugate=True,
+        )
+        return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        inplace: if True, apply rotary embedding in-place.
+        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding to the first rotary_dim of x.
+    """
+    return ApplyRotaryEmb.apply(
+        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+# For backward compatibility
+apply_rotary_emb_func = apply_rotary_emb
+
+
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        qkv,
+        cos,
+        sin,
+        cos_k=None,
+        sin_k=None,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        # batch, seqlen, three, nheads, headdim = qkv.shape
+        assert qkv.shape[-3] == 3
+        if cos_k is None and sin_k is None and qkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
+            # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            apply_rotary(
+                qk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
+            apply_rotary(
+                q,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            apply_rotary(
+                k,
+                cos_k,
+                sin_k,
+                seqlen_offsets,
+                interleaved=interleaved,
+                inplace=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.max_seqlen = max_seqlen
+        ctx.interleaved = interleaved
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            dqk = rearrange(dqkv[..., :2, :, :], "... t h d -> ... (t h) d")
+            apply_rotary(
+                dqk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
+            apply_rotary(
+                dq,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+            apply_rotary(
+                dk,
+                cos_k,
+                sin_k,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=ctx.max_seqlen,
+            )
+        return dqkv, None, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb_qkv_(
+    qkv,
+    cos,
+    sin,
+    cos_k=None,
+    sin_k=None,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 3, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 3, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
+    """
+    return ApplyRotaryEmbQKV_.apply(
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+class ApplyRotaryEmbKV_(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        kv,
+        cos,
+        sin,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ):
+        # batch, seqlen, two, nheads, headdim = kv.shape
+        assert kv.shape[-3] == 2
+        k = kv[..., 0, :, :]
+        apply_rotary(
+            k,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=interleaved,
+            inplace=True,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.max_seqlen = max_seqlen
+        ctx.interleaved = interleaved
+        return kv
+
+    @staticmethod
+    def backward(ctx, dkv):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cu_seqlens = ctx.saved_tensors
+        apply_rotary(
+            dkv[..., 0, :, :],
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=ctx.interleaved,
+            inplace=True,
+            conjugate=True,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+        )
+        return dkv, None, None, None, None, None, None
+
+
+apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
+
+
+def apply_rotary_emb_kv_(
+    kv,
+    cos,
+    sin,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 2, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        kv: (batch_size, seqlen, 2, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, 2, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of K.
+    """
+    return ApplyRotaryEmbKV_.apply(
+        kv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen
+    )
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    The rotary position embeddings from RoFormer_ (Su et. al).
+    A crucial insight from the method is that the query and keys are
+    transformed by rotation matrices which depend on the relative positions.
+
+    Other implementations are available in the Rotary Transformer repo_ and in
+    GPT-NeoX_, GPT-NeoX was an inspiration
+
+    .. _RoFormer: https://arxiv.org/abs/2104.09864
+    .. _repo: https://github.com/ZhuiyiTechnology/roformer
+    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
+
+    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
+    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        base=10000.0,
+        interleaved=False,
+        scale_base=None,
+        pos_idx_in_fp32=True,
+        device=None,
+    ):
+        """
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+            otherwise they might be in lower precision.
+            This option was added because previously (before 2023-07-02), when we construct
+            the position indices, we use the dtype of self.inv_freq. In most cases this would
+            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+            self.inv_freq would be bf16, and the position indices are also in bf16.
+            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+            embeddings for some positions will coincide.
+            To maintain compatibility with models previously trained in pure bf16,
+            we add this option.
+        """
+        super().__init__()
+        self.dim = dim
+        self.base = float(base)
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.interleaved = interleaved
+        self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)
+
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        # Reset the tables if the sequence length has changed,
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, inv_freq)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+
+    def forward(
+        self,
+        qkv: torch.Tensor,
+        kv: Optional[torch.Tensor] = None,
+        seqlen_offset: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
+            else it's just q of shape (batch, seqlen, nheads, headdim)
+        kv: (batch, seqlen, 2, nheads, headdim)
+        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
+            should pass in max_seqlen, which will update the cos / sin cache up to that length.
+        Apply rotary embedding *inplace* to qkv and / or kv.
+        """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+        seqlen = qkv.shape[1] if max_seqlen is None else max_seqlen
+        if max_seqlen is not None:
+            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if kv is None:
+            if self.scale is None:
+                return apply_rotary_emb_qkv_(
+                    qkv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                return apply_rotary_emb_qkv_(
+                    qkv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+        else:
+            q = qkv
+            q = apply_rotary_emb_func(
+                q,
+                self._cos_cached,
+                self._sin_cached,
+                interleaved=self.interleaved,
+                inplace=True,
+                seqlen_offsets=seqlen_offset,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+            )
+            if self.scale is None:
+                kv = apply_rotary_emb_kv_(
+                    kv,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                kv = apply_rotary_emb_kv_(
+                    kv,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            return q, kv
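
For reference, a minimal usage sketch of the module added above; shapes and values are illustrative, and it assumes flash-attn with its Triton rotary kernel plus a CUDA device, since the fallback `apply_rotary` stub only raises:

```python
import torch
# from .rotary import RotaryEmbedding  # as imported by mha.py in this PR

nheads, headdim = 12, 64
rotary = RotaryEmbedding(dim=headdim, base=10000.0).cuda()

# Padded path: (batch, seqlen, 3, nheads, headdim); Q and K are rotated in place.
qkv = torch.randn(2, 16, 3, nheads, headdim, device="cuda", dtype=torch.float16)
qkv = rotary(qkv)

# Unpadded (varlen) path: two sequences of lengths 5 and 11 packed into one token dimension.
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
qkv_packed = torch.randn(16, 3, nheads, headdim, device="cuda", dtype=torch.float16)
qkv_packed = rotary(qkv_packed, cu_seqlens=cu_seqlens, max_seqlen=11)
```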