feat: add current_task to forward
Browse files- modeling_lora.py +3 -1
modeling_lora.py
CHANGED
@@ -259,7 +259,9 @@ class BertLoRA(BertPreTrainedModel):
|
|
259 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
260 |
)
|
261 |
|
262 |
-
def forward(self, *args, **kwargs):
|
|
|
|
|
263 |
return self.bert(*args, **kwargs)
|
264 |
|
265 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
|
|
259 |
partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
|
260 |
)
|
261 |
|
262 |
+
def forward(self, *args, **kwargs, current_task: Union[None, int] = -1):
|
263 |
+
if current_task is None or current_task >= 0:
|
264 |
+
self.current_task = current_task
|
265 |
return self.bert(*args, **kwargs)
|
266 |
|
267 |
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|