jupyterjazz committed
Commit e860caa
Parent: 6a92924
Files changed (6)
  1. block.py +1 -1
  2. embedding.py +4 -3
  3. mha.py +11 -3
  4. mlp.py +4 -3
  5. modeling_lora.py +51 -91
  6. modeling_xlm_roberta.py +15 -11
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
                     is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states)
+                mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'))
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
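The only change here is that the block pulls the LoRA task selector out of the mixer kwargs it already receives for attention and hands it to the MLP. A minimal toy sketch of that hand-off (not the repository's Block class; the 'retrieval' task name is made up), using dict.get so callers that never set the key still work:

import torch
from torch import nn

class TinyBlock(nn.Module):
    """Toy stand-in for a transformer block; names are illustrative only."""
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU())

    def forward(self, hidden_states, mixer_kwargs=None):
        mixer_kwargs = mixer_kwargs or {}
        # dict.get returns None when 'task_type' was never set, so the block
        # behaves identically for callers that know nothing about LoRA tasks.
        task_type = mixer_kwargs.get('task_type')
        # A LoRA-patched MLP would receive task_type here; this plain one ignores it.
        return self.mlp(hidden_states), task_type

block = TinyBlock(8)
out, task = block(torch.randn(2, 3, 8), {'cu_seqlens': None, 'task_type': 'retrieval'})
print(out.shape, task)  # torch.Size([2, 3, 8]) retrieval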
embedding.py CHANGED
@@ -40,14 +40,15 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        embeddings = self.word_embeddings(input_ids)
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+        embeddings = self.word_embeddings(input_ids, **lora_kwargs)
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
@@ -57,6 +58,6 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             if token_type_ids is None:
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
             embeddings = embeddings + token_type_embeddings
         return embeddings
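The `lora_kwargs` dict is built conditionally because a stock `nn.Embedding` (or `nn.Linear`) that was never LoRA-patched does not accept a `task_type` keyword. A small self-contained sketch of why expanding an empty dict keeps the unpatched path working:

import torch
from torch import nn

emb = nn.Embedding(10, 4)          # an ordinary, unpatched embedding layer
ids = torch.tensor([[1, 2, 3]])

task_type = None                   # no adapter requested for this call
lora_kwargs = {'task_type': task_type} if task_type is not None else {}

# With an empty dict this is exactly emb(ids); passing task_type=... unconditionally
# would raise TypeError on a stock nn.Embedding, which the conditional dict avoids.
out = emb(ids, **lora_kwargs)
print(out.shape)                   # torch.Size([1, 3, 4])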
mha.py CHANGED
@@ -450,6 +450,7 @@ class MHA(nn.Module):
 
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
+
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
@@ -589,6 +590,7 @@ class MHA(nn.Module):
         max_seqlen=None,
         mixer_subset=None,
         inference_params=None,
+        task_type=None,
         **kwargs,
     ):
         """
@@ -643,10 +645,14 @@ class MHA(nn.Module):
         batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
+            lora_kwargs = {'task_type': task_type} if task_type is not None else {}
             if not self.return_residual:
-                qkv = self.Wqkv(x)
+                qkv = self.Wqkv(x, **lora_kwargs)
             else:
-                qkv, x = self.Wqkv(x)
+                if lora_kwargs:
+                    lora_kwargs['residual'] = True
+                qkv, x = self.Wqkv(x, **lora_kwargs)
+
             if self.dwconv:
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -731,5 +737,7 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+
+        lora_kwargs.pop('residual', None)
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
         return out if not self.return_residual else (out, x)
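Because `Wqkv` can be a `LinearResidual` that returns both the projection and its input, the patched forward needs the same contract, which is what setting `lora_kwargs['residual'] = True` requests; the flag is popped again before `out_proj`, which never returns a residual. A rough stand-in for that convention, assuming the patched-Linear behaviour described in modeling_lora.py:

import torch
from torch import nn
from torch.nn import functional as F

def patched_linear_forward(layer: nn.Linear, x: torch.Tensor, residual: bool = False):
    """Mimics the patched Linear.forward: optionally return the input alongside the
    projection, matching the `qkv, x = self.Wqkv(x)` unpacking used for LinearResidual."""
    out = F.linear(x, layer.weight, layer.bias)
    return (out, x) if residual else out

wqkv = nn.Linear(8, 24)            # toy sizes, illustrative only
x = torch.randn(2, 5, 8)
qkv, x_back = patched_linear_forward(wqkv, x, residual=True)
print(qkv.shape, x_back.shape)     # torch.Size([2, 5, 24]) torch.Size([2, 5, 8])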
mlp.py CHANGED
@@ -47,10 +47,11 @@ class Mlp(nn.Module):
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
-    def forward(self, x):
-        y = self.fc1(x)
+    def forward(self, x, task_type=None):
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+        y = self.fc1(x, **lora_kwargs)
         y = self.activation(y)
-        y = self.fc2(y)
+        y = self.fc2(y, **lora_kwargs)
         return y if not self.return_residual else (y, x)
 
 
modeling_lora.py CHANGED
@@ -9,6 +9,7 @@ import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import PretrainedConfig
 
 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
@@ -88,22 +89,19 @@ class LoRAParametrization(nn.Module):
             torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
             persistent=False,
         )
-        self.forward_fn = lambda x: x
-        self.current_task = None
 
     def _dropout(self, A):
         # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
         return A * self.lora_dropout(self.lora_dropout_mask)
 
-    def lora_forward(self, X):
-        assert self.current_task is not None
+    def lora_forward(self, X, current_task):
         return (
             X
             + torch.matmul(
                 *self.swap(
                     (
-                        self.lora_B[self.current_task],
-                        self.dropout_fn(self.lora_A[self.current_task]),
+                        self.lora_B[current_task],
+                        self.dropout_fn(self.lora_A[current_task]),
                     )
                 )
             ).view(X.shape)
@@ -111,19 +109,7 @@ class LoRAParametrization(nn.Module):
         )
 
     def forward(self, X):
-        return self.forward_fn(X)
-
-    @property
-    def current_task(self):
-        return self._current_task
-
-    @current_task.setter
-    def current_task(self, task: Union[None, int]):
-        self._current_task = task
-        if task is None:
-            self.forward_fn = lambda x: x
-        else:
-            self.forward_fn = self.lora_forward
+        return X
 
     @classmethod
     def from_linear(
@@ -175,6 +161,7 @@ class LoRAParametrization(nn.Module):
         rank: int,
        dropout_p: float,
         alpha: float,
+        adaptation_map: dict,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -188,6 +175,22 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
+
+            def new_forward(self, input, task_type, residual=False):
+                task_idx = adaptation_map[task_type] if task_type else None
+                if task_idx is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.linear(input, weights, self.bias)
+
+                if residual:
+                    return out, input
+                return out
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
+
         elif isinstance(layer, nn.Embedding):
             parametrize.register_parametrization(
                 layer,
@@ -201,10 +204,20 @@ class LoRAParametrization(nn.Module):
                 ),
             )
 
-    @staticmethod
-    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
-        if isinstance(layer, LoRAParametrization):
-            layer.current_task = task_idx
+            def new_forward(self, input, task_type):
+                task_idx = adaptation_map[task_type] if task_type else None
+                if task_idx is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.embedding(
+                    input, weights, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+                return out
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
 
 
 class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
@@ -251,9 +264,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             alpha=self._alpha,
         )
         self.main_params_trainable = config.lora_main_params_trainable
-        self._task_idx = None
-        # By default, disable LoRA until it's specified which adapter/task to use
-        self.current_task = None
+
 
     @property
     def main_params_trainable(self):
@@ -307,51 +318,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 rank=rank,
                 dropout_p=dropout_p,
                 alpha=alpha,
+                adaptation_map=self._adaptation_map,
             )
         )
 
-    @property
-    def current_task(self):
-        """Which LoRA is currently selected
-        :return: Integer or None (when LoRA is disabled)
-        """
-        return self._task_idx
-
-    @current_task.setter
-    def current_task(self, task_name: Union[None, str]):
-        """Set the LoRA that is to be used.
-        The LoRA is specified by `task_idx`, which may be an integer >= 0,
-        indexing the available LoRAs. If it is None, no LoRA is used.
-        :param task_name: Which LoRA to use
-        :return:
-        """
-        if task_name and task_name not in self._lora_adaptations:
-            raise ValueError(
-                f"Unsupported task '{task_name}'. "
-                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                f"Alternatively, set `task` to `None` if you want to disable LoRA."
-            )
-        task_idx = self._adaptation_map[task_name] if task_name else None
-        if self._task_idx != task_idx:
-            # In this case, we need to update the LoRAs everywhere
-            self._task_idx = task_idx
-            self.apply(
-                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
-            )
-
-    def forward(self, *args, task_type: Union[str, None] = None, **kwargs):
-        if task_type:
-            self.current_task = task_type
-        else:
-            input_ids = kwargs["input_ids"]
-            input_text = self.roberta.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-            for task_name, prompt in self._lora_prompts.items():
-                if input_text.startswith(prompt):
-                    self.current_task = task_name
-                    break
-            else:
-                self.current_task = None  # No task-specific adapter is found, just use the general-purpose weights
-
+    def forward(self, *args, **kwargs):
         return self.roberta(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -371,33 +342,22 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def encode(
         self,
         *args,
-        task_type: Union[str, None] = None,
+        task_type: Optional[str] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
 
-        task(`str`, *optional*, defaults to `LORA_NO_UPDATE`):
-            Specifies the task for which the encoding is intended. This parameter controls the
-            use of specialized LoRA adapters that are tuned for specific tasks. If `task` is set
-            to `None`, all LoRA adapters are disabled, and the model reverts to its original,
-            general-purpose weights. If `task` is set to a specific LoRA adaptation, that adaptation
-            is activated.
+        task_type(`str`, *optional*, defaults to `None`):
+            Specifies the task for which the encoding is intended. If `task_type` is not provide,
+            all LoRA adapters are disabled, and the model reverts to its original,
+            general-purpose weights.
         """
-        if task_type:
-            self.current_task = task_type
-        else:  # infer the task from the input text
-            input_text = args[0][0] if isinstance(args[0], list) else args[0]  # take only the first sentence
-            for task_name, prompt in self._lora_prompts.items():
-                if input_text.startswith(prompt):
-                    self.current_task = task_name
-                    break
-            else:
-                warnings.warn(
-                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
-                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
-                    category=UserWarning,
-                )
-                self.current_task = None  # No task-specific adapter is found, just use the general-purpose weights
+        if task_type and task_type not in self._lora_adaptations:
+            raise ValueError(
+                f"Unsupported task '{task_type}'. "
+                f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
+                f"Alternatively, don't pass the `task_type` argument to disable LoRA."
+            )
 
-        return self.roberta.encode(*args, **kwargs)
+        return self.roberta.encode(*args, task_type=task_type, **kwargs)
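The net effect of this file's changes: instead of storing a `current_task` on every `LoRAParametrization` and flipping it with `model.apply(...)`, the task is now chosen per call. `add_to_layer` still registers the parametrization, but it also rebinds the layer's `forward` so a `task_type` keyword selects one slice of `lora_B`/`lora_A` and merges it into the weight for that call only. A simplified, self-contained sketch of the idea (class name, sizes, and the omitted dropout/scaling are illustrative, not the repository's code):

import torch
from torch import nn
from torch.nn import functional as F

class MultiTaskLoRALinear(nn.Module):
    """Stateless per-call LoRA selection over several adapters."""
    def __init__(self, in_features, out_features, num_tasks, rank):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(num_tasks, rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(num_tasks, out_features, rank))
        nn.init.normal_(self.lora_A, std=0.02)   # lora_B stays zero, so adapters start as no-ops

    def forward(self, x, task_idx=None, residual=False):
        weight = self.base.weight
        if task_idx is not None:
            # pick the adapter for this call only; no module-level state is mutated
            weight = weight + self.lora_B[task_idx] @ self.lora_A[task_idx]
        out = F.linear(x, weight, self.base.bias)
        return (out, x) if residual else out

layer = MultiTaskLoRALinear(16, 32, num_tasks=3, rank=4)
x = torch.randn(2, 16)
print(layer(x).shape)               # base weights only
print(layer(x, task_idx=1).shape)   # adapter 1 merged into the weight for this call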
modeling_xlm_roberta.py CHANGED
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
 
-    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
@@ -215,6 +215,7 @@ class XLMRobertaEncoder(nn.Module):
                 if key_padding_mask is not None
                 else None
             )
+            mixer_kwargs['task_type'] = task_type
             for layer in self.layers:
                 if self._grad_checkpointing:
                     hidden_states = torch.utils.checkpoint.checkpoint(
@@ -232,7 +233,7 @@ class XLMRobertaEncoder(nn.Module):
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
-            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
+            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
             if subset_mask is None:
                 for layer in self.layers:
                     if self._grad_checkpointing:
@@ -309,11 +310,13 @@ class XLMRobertaPooler(nn.Module):
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
 
-    def forward(self, hidden_states, pool=True):
+    def forward(self, hidden_states, pool=True, task_type=None):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.dense(first_token_tensor, **lora_kwargs)
         pooled_output = self.activation(pooled_output)
         return pooled_output
 
@@ -454,6 +457,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
+        task_type: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -538,14 +542,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-
+        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input)[0]
+            token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
 
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
@@ -633,7 +637,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
-
+        task_type = kwargs.pop('task_type', None)
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
@@ -647,7 +651,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         hidden_states = self.embeddings(
-            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
+            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
         )
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
@@ -671,12 +675,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             subset_mask = None
 
         sequence_output = self.encoder(
-            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
+            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
         )
 
         if masked_tokens_mask is None:
             pooled_output = (
-                self.pooler(sequence_output) if self.pooler is not None else None
+                self.pooler(sequence_output, task_type=task_type) if self.pooler is not None else None
             )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
@@ -690,7 +694,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             pool_input = sequence_output[first_col_mask[subset_mask]]
             sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
             pooled_output = (
-                self.pooler(pool_input, pool=False) if self.pooler is not None else None
+                self.pooler(pool_input, pool=False, task_type=task_type) if self.pooler is not None else None
             )
 
         if not return_dict:
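Taken together, the commit turns task selection into a plain keyword argument that travels encode → forward → encoder → blocks → LoRA-patched Linear/Embedding layers, with no state left on the model between calls. A hypothetical usage sketch (the checkpoint name and the 'retrieval' adapter name are placeholders; `task_type` must match an entry in `config.lora_adaptations`):

from transformers import AutoModel

# Placeholder checkpoint; any repo that ships this modeling code with trust_remote_code works.
model = AutoModel.from_pretrained('<org>/<lora-checkpoint>', trust_remote_code=True)

# task_type is validated against config.lora_adaptations and forwarded per call.
embeddings = model.encode(['how do I reset my password?'], task_type='retrieval')

# Without task_type the patched layers fall back to the base weights.
baseline = model.encode(['how do I reset my password?'])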