Markus28 committed
Commit 6aad619
1 Parent(s): 7771ce3

feat: use property instead of setter

Files changed (1):
  1. modeling_lora.py +8 -4
modeling_lora.py CHANGED

@@ -186,7 +186,7 @@ class BertLoRA(BertPreTrainedModel):
         for name, param in super().named_parameters():
             if "lora" not in name:
                 param.requires_grad_(False)
-        self.select_task(0)
+        self.current_task = 0
 
     @classmethod
     def from_bert(cls, *args, num_adaptions=1, **kwargs):
@@ -194,7 +194,6 @@ class BertLoRA(BertPreTrainedModel):
         config = JinaBertConfig.from_pretrained(*args, **kwargs)
         return cls(config, bert=bert, num_adaptions=num_adaptions)
 
-
     @classmethod
     def from_pretrained(
         cls,
@@ -213,7 +212,6 @@ class BertLoRA(BertPreTrainedModel):
         # TODO: choose between from_bert and super().from_pretrained
         return cls.from_bert(pretrained_model_name_or_path)
 
-
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
         self.apply(
             partial(
@@ -225,7 +223,13 @@ class BertLoRA(BertPreTrainedModel):
             )
         )
 
-    def select_task(self, task_idx: Union[None, int]):
+    @property
+    def current_task(self):
+        return self._task_idx
+
+    @current_task.setter
+    def current_task(self, task_idx: Union[None, int]):
+        self._task_idx = task_idx
         self.apply(
             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
         )
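
For readers skimming the diff, a hedged sketch of what this change looks like at call sites. The import path, checkpoint name, adapter count, and task index below are assumptions for illustration, not taken from this commit:

    # Illustrative only: import path, checkpoint name and num_adaptions
    # are assumptions, not part of this commit.
    from modeling_lora import BertLoRA

    model = BertLoRA.from_bert("bert-base-uncased", num_adaptions=2)

    # Before this commit, switching the active LoRA adapter was a method call:
    #     model.select_task(1)
    # After it, the same switch is a plain attribute assignment that routes
    # through the new property setter:
    model.current_task = 1

    # The property also adds a getter, so the active adapter index can be
    # read back, which the old setter-only API did not expose:
    print(model.current_task)  # -> 1

Note that the __init__ change (self.current_task = 0 replacing self.select_task(0)) relies on this: the assignment goes through the setter, so adapter 0 is still selected on construction.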
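More generally, the pattern the commit adopts is Python's @property / setter pair, which lets an ordinary attribute assignment trigger side effects (here, storing the index and re-selecting the task in every LoRA layer). A minimal self-contained sketch, with invented names:

    class TaskMux:
        """Stand-alone illustration of the property/setter pattern used in
        this commit; the class and its names are invented for the example."""

        def __init__(self):
            self.current_task = 0  # routes through the setter below

        @property
        def current_task(self):
            # Reading the attribute returns the stored index.
            return self._task_idx

        @current_task.setter
        def current_task(self, task_idx):
            # Writing the attribute stores the index; a real implementation
            # can also run side effects here, as BertLoRA does with
            # self.apply(partial(LoRAParametrization.select_task_for_layer, ...)).
            self._task_idx = task_idx


    mux = TaskMux()
    mux.current_task = 3
    assert mux.current_task == 3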