jupyterjazz committed
Commit ffd672d
Parent(s): b20a611
refactor-task-type-to-task (#43)
- rename task type (3afddee7275504c48afc63049db9124f9e2871ce)
- modeling_lora.py +10 -10
- modeling_xlm_roberta.py +1 -1
modeling_lora.py
CHANGED
@@ -367,35 +367,35 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self,
         sentences: Union[str, List[str]],
         *args,
-        task_type: Optional[str] = None,
+        task: Optional[str] = None,
         **kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings.
         sentences(`str` or `List[str]`):
             Sentence or sentences to be encoded
-        task_type(`str`, *optional*, defaults to `None`):
-            Specifies the task for which the encoding is intended. If `task_type` is not provided,
+        task(`str`, *optional*, defaults to `None`):
+            Specifies the task for which the encoding is intended. If `task` is not provided,
             all LoRA adapters are disabled, and the model reverts to its original,
             general-purpose weights.
         """
-        if task_type and task_type not in self._lora_adaptations:
+        if task and task not in self._lora_adaptations:
             raise ValueError(
-                f"Unsupported task '{task_type}'. "
+                f"Unsupported task '{task}'. "
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                f"Alternatively, don't pass the `task_type` argument to disable LoRA."
+                f"Alternatively, don't pass the `task` argument to disable LoRA."
             )
         adapter_mask = None
-        if task_type:
-            task_id = self._adaptation_map[task_type]
+        if task:
+            task_id = self._adaptation_map[task]
             num_examples = 1 if isinstance(sentences, str) else len(sentences)
             adapter_mask = torch.full(
                 (num_examples,), task_id, dtype=torch.int32, device=self.device
             )
         if isinstance(sentences, str):
-            sentences = self._task_instructions[task_type] + sentences
+            sentences = self._task_instructions[task] + sentences
         else:
-            sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
+            sentences = [self._task_instructions[task] + sentence for sentence in sentences]
         return self.roberta.encode(
             sentences, *args, adapter_mask=adapter_mask, **kwargs
         )
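For context, a minimal usage sketch of the renamed argument after this commit; the repo id and task name below are illustrative placeholders, not taken from this commit:

```python
# Hedged sketch: assumes a checkpoint that ships this custom encode() via
# trust_remote_code. "org/lora-embedding-model" and "retrieval" are placeholders;
# valid task names come from config.lora_adaptations of the actual model.
from transformers import AutoModel

model = AutoModel.from_pretrained("org/lora-embedding-model", trust_remote_code=True)

# Before this commit: model.encode(texts, task_type="retrieval")
# After this commit:
embeddings = model.encode(
    ["A sentence to embed."],
    task="retrieval",  # must appear in config.lora_adaptations, else ValueError
)

# Omitting `task` disables all LoRA adapters and falls back to the base weights.
base_embeddings = model.encode(["A sentence to embed."])
```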
modeling_xlm_roberta.py
CHANGED
@@ -473,7 +473,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         normalize_embeddings: bool = True,
         truncate_dim: Optional[int] = None,
         adapter_mask: Optional[torch.Tensor] = None,
-        task_type: Optional[str] = None,
+        task: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
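One plausible reason the base model's `encode` also keeps an explicit `task` parameter (an assumption on my part, not stated in the commit): without a named parameter, a `task=...` keyword reaching this method would be swallowed by `**tokenizer_kwargs` and forwarded to the tokenizer, which does not accept it. A toy sketch of that mechanic:

```python
# Illustrative only: shows how an explicit keyword parameter keeps `task`
# out of **tokenizer_kwargs. These toy functions are not from the repo.
def encode_without_param(**tokenizer_kwargs):
    return tokenizer_kwargs  # task="..." would leak into the tokenizer kwargs

def encode_with_param(task=None, **tokenizer_kwargs):
    return tokenizer_kwargs  # task is captured by the signature instead

assert "task" in encode_without_param(task="retrieval")
assert "task" not in encode_with_param(task="retrieval")
```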