Commit 493416f by jupyterjazz
Parents: d9d8306 6a92924

feat: merge with recent changes

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (8)
  1. block.py +1 -1
  2. configuration_xlm_roberta.py +2 -0
  3. embedding.py +2 -2
  4. mha.py +2 -2
  5. mlp.py +2 -2
  6. modeling_lora.py +23 -19
  7. modeling_xlm_roberta.py +15 -21
  8. rotary.py +44 -21
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
                 is_rms_norm=isinstance(self.norm1, RMSNorm),
             )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states, task=mixer_kwargs.get('task'))
+                mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'))
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
             if not self.fused_dropout_add_ln:

configuration_xlm_roberta.py CHANGED
@@ -23,6 +23,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         lora_adaptations=None,
+        lora_prompts=None,
         lora_rank=4,
         lora_dropout_p=0.0,
         lora_alpha=1,
@@ -55,6 +56,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
+        self.lora_prompts = lora_prompts
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha

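Note: a minimal sketch (not part of this commit) of how the new `lora_prompts` field is expected to pair with `lora_adaptations`, based on the validation added in modeling_lora.py below. The task names and prompt strings are illustrative placeholders, not values shipped with any checkpoint.

    # Hypothetical config: every key of lora_prompts must appear in lora_adaptations,
    # and both must have the same number of entries (see the check in modeling_lora.py).
    from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumes the module is importable locally

    config = XLMRobertaFlashConfig(
        lora_adaptations=["retrieval.query", "retrieval.passage"],        # example task names
        lora_prompts={
            "retrieval.query": "Represent the query for retrieval: ",     # example prompts
            "retrieval.passage": "Represent the passage for retrieval: ",
        },
        lora_rank=4,
        lora_alpha=1,
    )
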
embedding.py CHANGED
@@ -40,14 +40,14 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, task=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        lora_kwargs = {'task': task} if task is not None else {}
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
         embeddings = self.word_embeddings(input_ids, **lora_kwargs)
         if self.max_position_embeddings > 0:
             if position_ids is None:

mha.py CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
         max_seqlen=None,
         mixer_subset=None,
         inference_params=None,
-        task=None,
+        task_type=None,
         **kwargs,
     ):
         """
@@ -645,7 +645,7 @@
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
-            lora_kwargs = {'task': task} if task is not None else {}
+            lora_kwargs = {'task_type': task_type} if task_type is not None else {}
             if not self.return_residual:
                 qkv = self.Wqkv(x, **lora_kwargs)
             else:

mlp.py CHANGED
@@ -47,8 +47,8 @@ class Mlp(nn.Module):
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
-    def forward(self, x, task):
-        lora_kwargs = {'task': task} if task is not None else {}
+    def forward(self, x, task_type=None):
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
         y = self.fc1(x, **lora_kwargs)
         y = self.activation(y)
         y = self.fc2(y, **lora_kwargs)

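Note: a plain `nn.Linear` does not accept a `task_type` keyword, so the `self.fc1(x, **lora_kwargs)` calls above only work once LoRAParametrization has rebound each patched layer's `forward` to the `new_forward(input, task_type, ...)` shown in modeling_lora.py below. A standalone sketch of that dispatch pattern (simplified; in the real code the per-task weights come from `parametrizations.weight[0].lora_forward`):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    adaptation_map = {"retrieval.query": 0, "classification": 1}  # example task names

    def new_forward(self, input, task_type=None):
        task_idx = adaptation_map[task_type] if task_type else None
        if task_idx is not None:
            # Real code: select the LoRA-adapted weight for this task index.
            weight = self.weight  # placeholder for the task-adapted weight
        else:
            weight = self.weight
        return F.linear(input, weight, self.bias)

    layer = nn.Linear(8, 8)
    layer.forward = new_forward.__get__(layer, nn.Linear)  # rebind the instance's forward
    out = layer(torch.randn(2, 8), task_type="retrieval.query")
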
modeling_lora.py CHANGED
@@ -15,9 +15,6 @@ from transformers import PretrainedConfig
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
 
 
-LORA_NO_UPDATE = '__lora_no_update__'
-
-
 def initialized_weights(
     shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
@@ -179,8 +176,8 @@ class LoRAParametrization(nn.Module):
             ),
         )
 
-        def new_forward(self, input, task, residual=False):
-            task_idx = adaptation_map[task] if task else None
+        def new_forward(self, input, task_type, residual=False):
+            task_idx = adaptation_map[task_type] if task_type else None
             if task_idx is not None:
                 weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
             else:
@@ -207,8 +204,8 @@ class LoRAParametrization(nn.Module):
             ),
         )
 
-        def new_forward(self, input, task):
-            task_idx = adaptation_map[task] if task else None
+        def new_forward(self, input, task_type):
+            task_idx = adaptation_map[task_type] if task_type else None
             if task_idx is not None:
                 weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
             else:
@@ -244,6 +241,16 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             raise ValueError(
                 f'`lora_adaptations` must be a list and contain at least one element'
             )
+        self._lora_prompts = config.lora_prompts
+        if (
+            not isinstance(self._lora_prompts, dict)
+            or len(self._lora_prompts) != len(self._lora_adaptations)
+            or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
+        ):
+            raise ValueError(
+                f'`lora_prompts` must be a dict and contain the same number of elements '
+                f'as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`.'
+            )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }
@@ -335,25 +342,22 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def encode(
         self,
         *args,
-        task: Optional[str] = None,
+        task_type: Optional[str] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
 
-        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
-            Specifies the task for which the encoding is intended. This parameter controls the
-            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
-            to `LORA_NO_UPDATE`, there will be no update to the current task, retaining the
-            existing adapter configuration. If `task` is explicitly set to `None`, all LoRA
-            adapters are disabled, and the model reverts to its original, general-purpose weights.
-            If `task` is set to a specific LoRA adaptation, that adaptation is activated.
+        task_type(`str`, *optional*, defaults to `None`):
+            Specifies the task for which the encoding is intended. If `task_type` is not provided,
+            all LoRA adapters are disabled, and the model reverts to its original,
+            general-purpose weights.
         """
-        if task and task not in self._lora_adaptations:
+        if task_type and task_type not in self._lora_adaptations:
             raise ValueError(
-                f"Unsupported task '{task}'. "
+                f"Unsupported task '{task_type}'. "
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                f"Alternatively, don't pass the `task` argument to disable LoRA."
+                f"Alternatively, don't pass the `task_type` argument to disable LoRA."
             )
 
-        return self.roberta.encode(*args, **kwargs)
+        return self.roberta.encode(*args, task_type=task_type, **kwargs)

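Note: a hedged usage sketch (not part of this diff) of the renamed `encode` argument. The checkpoint id and the task name are placeholders; `task_type` must match an entry of `config.lora_adaptations`, and omitting it (or passing `None`) disables the LoRA adapters and uses the base weights.

    from transformers import AutoModel

    # Placeholder repo id: any checkpoint that ships this modeling_lora.py as remote code.
    model = AutoModel.from_pretrained("org/checkpoint-with-this-code", trust_remote_code=True)

    embeddings = model.encode(
        ["A sample sentence."],
        task_type="retrieval.query",  # example value; must be listed in config.lora_adaptations
    )
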
modeling_xlm_roberta.py CHANGED
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from einops import rearrange
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, AutoTokenizer
 from transformers.modeling_utils import PreTrainedModel
 from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
 from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
 
-    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
@@ -215,7 +215,7 @@
                 if key_padding_mask is not None
                 else None
             )
-            mixer_kwargs['task'] = task
+            mixer_kwargs['task_type'] = task_type
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -233,7 +233,7 @@
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
-            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task": task}
+            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
             if subset_mask is None:
                 for layer in self.layers:
                     if self._grad_checkpointing:
@@ -310,10 +310,10 @@ class XLMRobertaPooler(nn.Module):
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states, pool=True, task=None):
+    def forward(self, hidden_states, pool=True, task_type=None):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
-        lora_kwargs = {'task': task} if task is not None else {}
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
 
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
         pooled_output = self.dense(first_token_tensor, **lora_kwargs)
@@ -443,7 +443,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
-
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
 
     @torch.inference_mode()
     def encode(
@@ -457,7 +457,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
-        task: Optional[str] = None,
+        task_type: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -496,12 +496,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             If convert_to_tensor, a stacked tensor is returned.
             If convert_to_numpy, a numpy matrix is returned.
         """
-        from transformers import AutoTokenizer
-
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.name_or_path, trust_remote_code=True
-        )
-
         is_training = self.training
         self.eval()
 
@@ -548,7 +542,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-        lora_kwargs = {'task': task} if task is not None else {}
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
@@ -643,7 +637,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
-        task = kwargs.pop('task', None)
+        task_type = kwargs.pop('task_type', None)
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
@@ -657,7 +651,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         hidden_states = self.embeddings(
-            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task=task
+            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
         )
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
@@ -681,12 +675,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             subset_mask = None
 
         sequence_output = self.encoder(
-            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task=task
+            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
         )
 
         if masked_tokens_mask is None:
             pooled_output = (
-                self.pooler(sequence_output, task=task) if self.pooler is not None else None
+                self.pooler(sequence_output, task_type=task_type) if self.pooler is not None else None
             )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
@@ -700,7 +694,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             pool_input = sequence_output[first_col_mask[subset_mask]]
             sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
             pooled_output = (
-                self.pooler(pool_input, pool=False, task=task) if self.pooler is not None else None
+                self.pooler(pool_input, pool=False, task_type=task_type) if self.pooler is not None else None
             )
 
         if not return_dict:
@@ -1282,4 +1276,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )

rotary.py CHANGED
@@ -6,11 +6,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 from einops import rearrange, repeat
-try:
-    from flash_attn.ops.triton.rotary import apply_rotary
-except ImportError:
-    def apply_rotary(*args, **kwargs):
-        raise RuntimeError('RoPE requires flash-attention to be installed')
+
+if torch.cuda.is_available():
+    try:
+        from flash_attn.ops.triton.rotary import apply_rotary
+    except ImportError:
+        def apply_rotary(*args, **kwargs):
+            raise RuntimeError('RoPE requires flash-attention to be installed')
 
 
 def rotate_half(x, interleaved=False):
@@ -29,6 +31,10 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     """
     ro_dim = cos.shape[-1] * 2
     assert ro_dim <= x.shape[-1]
+    cos, sin = (
+        cos[:x.shape[1]],
+        sin[:x.shape[1]],
+    )
     cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
     sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
     return torch.cat(
@@ -60,6 +66,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
             interleaved=interleaved,
             inplace=inplace,
         )
+
         if isinstance(seqlen_offsets, int):
             ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
             ctx.seqlen_offsets = seqlen_offsets
@@ -82,6 +89,7 @@
         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
         if not ctx.interleaved and not ctx.inplace:
             do = do.clone()
+
         dx = apply_rotary(
             do,
             cos,
@@ -150,21 +158,37 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         # batch, seqlen, three, nheads, headdim = qkv.shape
         assert qkv.shape[-3] == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
-            # Call 1 kernel instead of 2 kernels
-            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
-            # dimensions, we get the same tensor
-            qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
-            # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
-            apply_rotary(
-                qk,
-                cos,
-                sin,
-                seqlen_offsets=seqlen_offsets,
-                interleaved=interleaved,
-                inplace=True,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
+
+            if torch.cuda.is_available():
+                # Call 1 kernel instead of 2 kernels
+                # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+                # dimensions, we get the same tensor
+                qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
+                # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+                apply_rotary(
+                    qk,
+                    cos,
+                    sin,
+                    seqlen_offsets=seqlen_offsets,
+                    interleaved=interleaved,
+                    inplace=True,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
+            else:
+                q_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 0],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                k_rot = apply_rotary_emb_torch(
+                    qkv[:, :, 1],
+                    cos,
+                    sin,
+                    interleaved=interleaved,
+                )
+                qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
@@ -228,7 +252,6 @@
         sin_k = sin if sin_k is None else sin_k
         dq, dk = dqkv[..., 0, :, :], dqkv[..., 1, :, :]
         apply_rotary(
-
            dq,
            cos,
            sin,

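Note: a minimal sketch (assumptions: rotary.py importable as a module, head_dim = rotary_dim = 8) exercising the pure-torch path that this diff routes to when CUDA is unavailable. It mirrors the new `else` branch above: rotate q and k with `apply_rotary_emb_torch` and leave v untouched.

    import torch
    from rotary import apply_rotary_emb_torch  # assumes rotary.py is on the import path

    batch, seqlen, nheads, headdim = 2, 4, 2, 8
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim)

    # cos/sin as produced by a rotary cache: shape (seqlen, rotary_dim // 2)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2).float() / headdim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()

    q_rot = apply_rotary_emb_torch(qkv[:, :, 0], cos, sin, interleaved=False)
    k_rot = apply_rotary_emb_torch(qkv[:, :, 1], cos, sin, interleaved=False)
    qkv_rotated = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)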