Jackmin108 committed
Commit 7c7eafb
1 Parent(s): 7815d41

feat: evaluation/encode


Signed-off-by: Meow <ongjackm@gmail.com>

Files changed (5)
  1. embedding.py +2 -2
  2. mha.py +3 -3
  3. mlp.py +2 -2
  4. modeling_lora.py +11 -20
  5. modeling_xlm_roberta.py +4 -4
embedding.py CHANGED
@@ -55,7 +55,7 @@ class XLMRobertaEmbeddings(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_input_ids = input_ids[task_indices]
-                task_embeddings = self.word_embeddings(task_input_ids, task_type=task_id)
+                task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
                 embeddings[task_indices] = task_embeddings
         else:
             embeddings = self.word_embeddings(input_ids)
@@ -73,7 +73,7 @@ class XLMRobertaEmbeddings(nn.Module):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask).tolist()
             for task_id in unique_tasks:
-                task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_type=task_id)
+                task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
         else:
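Every hunk in this commit applies the same per-task dispatch pattern: rows of the batch are grouped by adapter id, and each group is run through the LoRA-parametrized layer with its own task_id keyword. A minimal sketch of that pattern, with a hypothetical LoRALinear standing in for the repository's parametrized layers (not the actual implementation):

import torch
import torch.nn as nn

class LoRALinear(nn.Linear):
    # Hypothetical stand-in: a real LoRA-parametrized layer would select
    # task-specific low-rank weights based on task_id before projecting.
    def forward(self, input, task_id=None):
        return super().forward(input)

def dispatch_by_task(layer, x, adapter_mask):
    # Group batch rows by adapter id and run each group with its own task_id,
    # mirroring the loops in embedding.py, mha.py, mlp.py and the pooler.
    out = torch.empty(x.shape[0], layer.out_features, dtype=x.dtype, device=x.device)
    for task_id in torch.unique(adapter_mask).tolist():
        task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
        out[task_indices] = layer(x[task_indices], task_id=task_id)
    return out

layer = LoRALinear(4, 8)
x = torch.randn(6, 4)
adapter_mask = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)
y = dispatch_by_task(layer, x, adapter_mask)  # shape (6, 8), one task per row group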
mha.py CHANGED
@@ -655,9 +655,9 @@ class MHA(nn.Module):
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
                 if not self.return_residual:
-                    task_qkv = self.Wqkv(task_tensor, task_type=task_id)
+                    task_qkv = self.Wqkv(task_tensor, task_id=task_id)
                 else:
-                    task_qkv, _ = self.Wqkv(task_tensor, task_type=task_id, residual=True)
+                    task_qkv, _ = self.Wqkv(task_tensor, task_id=task_id, residual=True)
                 qkv[task_indices] = task_qkv
         else:
             if not self.return_residual:
@@ -759,7 +759,7 @@ class MHA(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = inp[task_indices]
-                task_out = self.out_proj(task_tensor, task_type=task_id)
+                task_out = self.out_proj(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
         else:
             out = self.out_proj(inp)
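Unlike the embedding and pooler hunks, the MHA (and MLP) loops index with cu_adapter_mask, i.e. a per-token mask matching the unpadded, concatenated-token layout those layers operate on. How that mask is built is not part of this commit; a plausible sketch, assuming one adapter id per sequence and known per-sequence token counts:

import torch

# Assumption: one adapter id per sequence, expanded to one id per token for
# the unpadded layout. `lengths` (tokens per sequence) is illustrative only.
adapter_mask = torch.tensor([0, 1, 1], dtype=torch.int32)
lengths = torch.tensor([3, 2, 4])
cu_adapter_mask = torch.repeat_interleave(adapter_mask, lengths)
# tensor([0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=torch.int32)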
mlp.py CHANGED
@@ -56,7 +56,7 @@ class Mlp(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
-                task_y = self.fc1(task_tensor, task_type=task_id)
+                task_y = self.fc1(task_tensor, task_id=task_id)
                 y[task_indices] = task_y
         else:
             y = self.fc1(x)
@@ -71,7 +71,7 @@ class Mlp(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = y[task_indices]
-                task_out = self.fc2(task_tensor, task_type=task_id)
+                task_out = self.fc2(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
         else:
             out = self.fc1(y)
modeling_lora.py CHANGED
@@ -161,7 +161,6 @@ class LoRAParametrization(nn.Module):
         rank: int,
         dropout_p: float,
         alpha: float,
-        adaptation_map: dict,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -176,14 +175,9 @@ class LoRAParametrization(nn.Module):
                 ),
             )
 
-            def new_forward(self, input, task_type, residual=False):
-                if isinstance(task_type, str):
-                    task_idx = adaptation_map[task_type] if task_type else None
-                else:
-                    task_idx = task_type
-
-                if task_idx is not None:
-                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            def new_forward(self, input, task_id=None, residual=False):
+                if task_id is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
                 else:
                     weights = self.weight
 
@@ -208,14 +202,9 @@ class LoRAParametrization(nn.Module):
                 ),
             )
 
-            def new_forward(self, input, task_type):
-                if isinstance(task_type, str):
-                    task_idx = adaptation_map[task_type] if task_type else None
-                else:
-                    task_idx = task_type
-
-                if task_idx is not None:
-                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            def new_forward(self, input, task_id=None):
+                if task_id is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
                 else:
                     weights = self.weight
 
@@ -325,7 +314,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 rank=rank,
                 dropout_p=dropout_p,
                 alpha=alpha,
-                adaptation_map=self._adaptation_map,
             )
         )
 
@@ -348,6 +336,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     @torch.inference_mode()
     def encode(
         self,
+        sentences: Union[str, List[str]],
         *args,
         task_type: Optional[str] = None,
         **kwargs,
@@ -366,5 +355,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
             )
-
-        return self.roberta.encode(*args, task_type=task_type, **kwargs)
+        task_id = self._adaptation_map[task_type]
+        num_examples = 1 if isinstance(sentences, str) else len(sentences)
+        adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32)
+        return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
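At the call site the public interface is unchanged: encode still takes a task_type string, but it now resolves the name to an integer id and materializes a per-sentence adapter_mask before delegating to the base model. A usage sketch, assuming `model` is an already loaded XLMRobertaLoRA and that 'retrieval.query' appears in its config.lora_adaptations (both assumptions, not taken from this diff):

embeddings = model.encode(
    ['A query about pandas', 'Another query'],
    task_type='retrieval.query',  # illustrative name; must be in config.lora_adaptations
)
# Per the hunks above, encode() maps the name through self._adaptation_map,
# builds adapter_mask = torch.full((2,), task_id, dtype=torch.int32),
# and calls self.roberta.encode(sentences, adapter_mask=adapter_mask, ...).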
modeling_xlm_roberta.py CHANGED
@@ -321,7 +321,7 @@ class XLMRobertaPooler(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_first_token_tensor = first_token_tensor[task_indices]
-                task_pooled_output = self.dense(task_first_token_tensor, task_type=task_id)
+                task_pooled_output = self.dense(task_first_token_tensor, task_id=task_id)
                 pooled_output[task_indices] = task_pooled_output
         else:
             pooled_output = self.dense(first_token_tensor)
@@ -464,7 +464,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
-        task_type: Optional[str] = None,
+        adapter_mask: Optional[torch.Tensor] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -549,14 +549,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-        lora_kwargs = {'task_type': task_type} if task_type is not None else {}
+        lora_arguments = {'adapter_mask': adapter_mask} if adapter_mask is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
+            token_embs = self.forward(**encoded_input, **lora_arguments)[0]
 
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
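The lower-level XLMRobertaModel.encode no longer deals in task names at all: it accepts an explicit per-sentence adapter_mask tensor and forwards it to self.forward for each tokenized batch. A sketch of a direct call, assuming `base_model` is a loaded XLMRobertaModel and that adapter ids 0 and 1 are valid in its adaptation map (both assumptions):

import torch

sentences = ['a query', 'a passage']
adapter_mask = torch.tensor([0, 1], dtype=torch.int32)  # one adapter id per sentence
embs = base_model.encode(sentences, adapter_mask=adapter_mask)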