jupyterjazz committed on
Commit
509511d
1 Parent(s): eefe43c

refactor: finalize impl

Browse files

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (6) hide show
  1. block.py +1 -1
  2. embedding.py +6 -3
  3. mha.py +12 -6
  4. mlp.py +6 -3
  5. modeling_lora.py +1 -56
  6. modeling_xlm_roberta.py +13 -8
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- mlp_out = self.mlp(hidden_states)
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states, task=mixer_kwargs.get('task'))
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
embedding.py CHANGED
@@ -40,14 +40,17 @@ class XLMRobertaEmbeddings(nn.Module):
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
- def forward(self, input_ids, position_ids=None, token_type_ids=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
- embeddings = self.word_embeddings(input_ids, task='sts')
 
 
 
51
  if self.max_position_embeddings > 0:
52
  if position_ids is None:
53
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
@@ -57,6 +60,6 @@ class XLMRobertaEmbeddings(nn.Module):
57
  if self.type_vocab_size > 0:
58
  if token_type_ids is None:
59
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
60
- token_type_embeddings = self.token_type_embeddings(token_type_ids, task='sts')
61
  embeddings = embeddings + token_type_embeddings
62
  return embeddings
 
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, task=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ lora_kwargs = {}
51
+ if task is not None:
52
+ lora_kwargs['task'] = task
53
+ embeddings = self.word_embeddings(input_ids, **lora_kwargs)
54
  if self.max_position_embeddings > 0:
55
  if position_ids is None:
56
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
 
60
  if self.type_vocab_size > 0:
61
  if token_type_ids is None:
62
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
63
+ token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
64
  embeddings = embeddings + token_type_embeddings
65
  return embeddings
mha.py CHANGED
@@ -340,9 +340,8 @@ class CrossAttention(nn.Module):
340
  class LinearResidual(nn.Linear):
341
  """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
342
 
343
- def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
344
- print('aq vafshe ar modis?')
345
- return super().forward(input, task=task), input
346
 
347
 
348
  def _update_kv_cache(kv, inference_params, layer_idx):
@@ -591,6 +590,7 @@ class MHA(nn.Module):
591
  max_seqlen=None,
592
  mixer_subset=None,
593
  inference_params=None,
 
594
  **kwargs,
595
  ):
596
  """
@@ -645,10 +645,15 @@ class MHA(nn.Module):
645
  batch, seqlen = x.shape[:2]
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
 
 
 
 
 
648
  if not self.return_residual:
649
- qkv = self.Wqkv(x)
650
  else:
651
- qkv, x = self.Wqkv(x, task='query', residual=True)
652
 
653
  if self.dwconv:
654
  qkv = rearrange(
@@ -734,5 +739,6 @@ class MHA(nn.Module):
734
  context = self._update_kvcache_attention(q, kv, inference_params)
735
  else:
736
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
737
- out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), task='passage')
 
738
  return out if not self.return_residual else (out, x)
 
340
  class LinearResidual(nn.Linear):
341
  """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
342
 
343
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
344
+ return super().forward(input), input
 
345
 
346
 
347
  def _update_kv_cache(kv, inference_params, layer_idx):
 
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
+ task=None,
594
  **kwargs,
595
  ):
596
  """
 
645
  batch, seqlen = x.shape[:2]
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
648
+ lora_kwargs = {}
649
+ if task:
650
+ lora_kwargs['task'] = task
651
+ lora_kwargs['residual'] = self.return_residual
652
+
653
  if not self.return_residual:
654
+ qkv = self.Wqkv(x, **lora_kwargs)
655
  else:
656
+ qkv, x = self.Wqkv(x, **lora_kwargs)
657
 
658
  if self.dwconv:
659
  qkv = rearrange(
 
739
  context = self._update_kvcache_attention(q, kv, inference_params)
740
  else:
741
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
742
+ lora_kwargs.pop('residual', None)
743
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
744
  return out if not self.return_residual else (out, x)
mlp.py CHANGED
@@ -47,10 +47,13 @@ class Mlp(nn.Module):
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
- def forward(self, x):
51
- y = self.fc1(x, task='clustering')
 
 
 
52
  y = self.activation(y)
53
- y = self.fc2(y, task='sts')
54
  return y if not self.return_residual else (y, x)
55
 
56
 
 
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
+ def forward(self, x, task):
51
+ lora_kwargs = {}
52
+ if task:
53
+ lora_kwargs['task'] = task
54
+ y = self.fc1(x, **lora_kwargs)
55
  y = self.activation(y)
56
+ y = self.fc2(y, **lora_kwargs)
57
  return y if not self.return_residual else (y, x)
58
 
59
 
modeling_lora.py CHANGED
@@ -92,8 +92,6 @@ class LoRAParametrization(nn.Module):
92
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
93
  persistent=False,
94
  )
95
- self.forward_fn = lambda x: x
96
- self.current_task = None
97
 
98
  def _dropout(self, A):
99
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
@@ -116,18 +114,6 @@ class LoRAParametrization(nn.Module):
116
  def forward(self, X):
117
  return X
118
 
119
- @property
120
- def current_task(self):
121
- return self._current_task
122
-
123
- @current_task.setter
124
- def current_task(self, task: Union[None, int]):
125
- self._current_task = task
126
- if task is None:
127
- self.forward_fn = lambda x: x
128
- else:
129
- self.forward_fn = self.lora_forward
130
-
131
  @classmethod
132
  def from_linear(
133
  cls,
@@ -239,12 +225,6 @@ class LoRAParametrization(nn.Module):
239
  layer.forward = new_forward.__get__(layer, layer.__class__)
240
 
241
 
242
- @staticmethod
243
- def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
244
- if isinstance(layer, LoRAParametrization):
245
- layer.current_task = task_idx
246
-
247
-
248
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
249
  def __init__(
250
  self,
@@ -279,9 +259,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
279
  alpha=self._alpha,
280
  )
281
  self.main_params_trainable = config.lora_main_params_trainable
282
- self._task_idx = None
283
- # By default, disable LoRA until it's specified which adapter/task to use
284
- self.current_task = None
285
 
286
 
287
  @property
@@ -340,39 +317,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
340
  )
341
  )
342
 
343
- @property
344
- def current_task(self):
345
- """Which LoRA is currently selected
346
- :return: Integer or None (when LoRA is disabled)
347
- """
348
- return self._task_idx
349
-
350
- @current_task.setter
351
- def current_task(self, task_name: Union[None, str]):
352
- """Set the LoRA that is to be used.
353
- The LoRA is specified by `task_idx`, which may be an integer >= 0,
354
- indexing the available LoRAs. If it is None, no LoRA is used.
355
- :param task_name: Which LoRA to use
356
- :return:
357
- """
358
- if task_name and task_name not in self._lora_adaptations:
359
- raise ValueError(
360
- f"Unsupported task '{task_name}'. "
361
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
362
- f"Alternatively, set `task` to `None` if you want to disable LoRA."
363
- )
364
- task_idx = self._adaptation_map[task_name] if task_name else None
365
- # if self._task_idx != task_idx:
366
- # # In this case, we need to update the LoRAs everywhere
367
- # self._task_idx = task_idx
368
- # self.apply(
369
- # partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
370
- # )
371
-
372
- def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
373
- if task != LORA_NO_UPDATE:
374
- self.current_task = task
375
-
376
  return self.roberta(*args, **kwargs)
377
 
378
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
 
92
  torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
93
  persistent=False,
94
  )
 
 
95
 
96
  def _dropout(self, A):
97
  # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
 
114
  def forward(self, X):
115
  return X
116
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  @classmethod
118
  def from_linear(
119
  cls,
 
225
  layer.forward = new_forward.__get__(layer, layer.__class__)
226
 
227
 
 
 
 
 
 
 
228
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
229
  def __init__(
230
  self,
 
259
  alpha=self._alpha,
260
  )
261
  self.main_params_trainable = config.lora_main_params_trainable
 
 
 
262
 
263
 
264
  @property
 
317
  )
318
  )
319
 
320
+ def forward(self, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  return self.roberta(*args, **kwargs)
322
 
323
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
modeling_xlm_roberta.py CHANGED
@@ -215,6 +215,7 @@ class XLMRobertaEncoder(nn.Module):
215
  if key_padding_mask is not None
216
  else None
217
  )
 
218
  for layer in self.layers:
219
  if self._grad_checkpointing:
220
  hidden_states = torch.utils.checkpoint.checkpoint(
@@ -232,7 +233,7 @@ class XLMRobertaEncoder(nn.Module):
232
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
233
  hidden_states, key_padding_mask
234
  )
235
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
236
  if subset_mask is None:
237
  for layer in self.layers:
238
  if self._grad_checkpointing:
@@ -309,11 +310,15 @@ class XLMRobertaPooler(nn.Module):
309
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
310
  self.activation = nn.Tanh()
311
 
312
- def forward(self, hidden_states, pool=True):
313
  # We "pool" the model by simply taking the hidden state corresponding
314
  # to the first token.
 
 
 
 
315
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
316
- pooled_output = self.dense(first_token_tensor, task='passage')
317
  pooled_output = self.activation(pooled_output)
318
  return pooled_output
319
 
@@ -639,7 +644,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
639
  layer output for these tokens.
640
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
641
  """
642
-
643
  if kwargs:
644
  for key, value in kwargs.items():
645
  if value is not None:
@@ -653,7 +658,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
653
  )
654
 
655
  hidden_states = self.embeddings(
656
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids
657
  )
658
  # TD [2022-12:18]: Don't need to force residual in fp32
659
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -677,12 +682,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
677
  subset_mask = None
678
 
679
  sequence_output = self.encoder(
680
- hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
681
  )
682
 
683
  if masked_tokens_mask is None:
684
  pooled_output = (
685
- self.pooler(sequence_output) if self.pooler is not None else None
686
  )
687
  else:
688
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -696,7 +701,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
696
  pool_input = sequence_output[first_col_mask[subset_mask]]
697
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
698
  pooled_output = (
699
- self.pooler(pool_input, pool=False) if self.pooler is not None else None
700
  )
701
 
702
  if not return_dict:
 
215
  if key_padding_mask is not None
216
  else None
217
  )
218
+ mixer_kwargs['task'] = task
219
  for layer in self.layers:
220
  if self._grad_checkpointing:
221
  hidden_states = torch.utils.checkpoint.checkpoint(
 
233
  hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
234
  hidden_states, key_padding_mask
235
  )
236
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task": task}
237
  if subset_mask is None:
238
  for layer in self.layers:
239
  if self._grad_checkpointing:
 
310
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
311
  self.activation = nn.Tanh()
312
 
313
+ def forward(self, hidden_states, pool=True, task=None):
314
  # We "pool" the model by simply taking the hidden state corresponding
315
  # to the first token.
316
+ lora_kwargs = {}
317
+ if task:
318
+ lora_kwargs['task'] = task
319
+
320
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
321
+ pooled_output = self.dense(first_token_tensor, **lora_kwargs)
322
  pooled_output = self.activation(pooled_output)
323
  return pooled_output
324
 
 
644
  layer output for these tokens.
645
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
646
  """
647
+ task = kwargs.pop('task', None)
648
  if kwargs:
649
  for key, value in kwargs.items():
650
  if value is not None:
 
658
  )
659
 
660
  hidden_states = self.embeddings(
661
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task=task
662
  )
663
  # TD [2022-12:18]: Don't need to force residual in fp32
664
  # BERT puts embedding LayerNorm before embedding dropout.
 
682
  subset_mask = None
683
 
684
  sequence_output = self.encoder(
685
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task=task
686
  )
687
 
688
  if masked_tokens_mask is None:
689
  pooled_output = (
690
+ self.pooler(sequence_output, task=task) if self.pooler is not None else None
691
  )
692
  else:
693
  # TD [2022-03-01]: the indexing here is very tricky.
 
701
  pool_input = sequence_output[first_col_mask[subset_mask]]
702
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
703
  pooled_output = (
704
+ self.pooler(pool_input, pool=False, task=task) if self.pooler is not None else None
705
  )
706
 
707
  if not return_dict: