Markus28 committed
Commit 702e6c9
Parent: 284316d

feat: Allow LoRA to be merged into weights (#12)


- feat: added method to merge LoRA weights (a2b49b2c72b661d680e6fde184d5ef05a5ebc974)
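
A short sketch of the workflow this change enables; the checkpoint path and task index below are placeholders, not part of the commit:

# Hypothetical usage of the merge_lora() method added in this commit.
model = BertLoRA.from_pretrained("path/to/lora-checkpoint")
model.current_task = 0   # pick the adapter to fold into the base weights
model.merge_lora()       # bakes that adapter in and drops the parametrizations
# A second merge, or selecting a new task afterwards, raises an Exception.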

Files changed (1):
  modeling_lora.py  +16 -0
modeling_lora.py CHANGED
@@ -199,6 +199,12 @@ class LoRAParametrization(nn.Module):
         if isinstance(layer, LoRAParametrization):
             layer.current_task = task_idx
 
+    @classmethod
+    def merge_lora_into_layer(cls, layer: nn.Module):
+        if hasattr(layer, "parametrizations"):
+            for attr_name in layer.parametrizations.keys():
+                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
+
 
 class BertLoRA(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
@@ -207,6 +213,7 @@ class BertLoRA(BertPreTrainedModel):
             self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         else:
             self.bert = bert
+        self._is_merged = False
         self._num_adaptions = config.num_loras
         self._register_lora(self._num_adaptions)
         self.main_params_trainable = False
@@ -230,6 +237,13 @@ class BertLoRA(BertPreTrainedModel):
         config = JinaBertConfig.from_pretrained(*args, **kwargs)
         return cls(config, bert=bert)
 
+    def merge_lora(self):
+        """Merges currently selected LoRA into main weights."""
+        if self._is_merged:
+            raise Exception('LoRA has already been merged, cannot merge again')
+        self._is_merged = True
+        self.apply(LoRAParametrization.merge_lora_into_layer)
+
     @classmethod
     def from_pretrained(
         cls,
@@ -265,6 +279,8 @@ class BertLoRA(BertPreTrainedModel):
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        if self._is_merged:
+            raise Exception('LoRA has been merged, cannot select new task')
         assert task_idx is None or 0 <= task_idx < self._num_adaptions
         if self._task_idx != task_idx:
             self._task_idx = task_idx
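
For context, the merge relies on torch.nn.utils.parametrize: calling remove_parametrizations with leave_parametrized=True materializes the parametrized value W + B @ A into the plain weight tensor before the parametrization (and its LoRA parameters) is discarded. Below is a minimal, self-contained sketch using a toy rank-decomposition parametrization, not the repo's LoRAParametrization class, and arbitrary layer sizes, showing that the merged weight equals W + B @ A:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class ToyLoRA(nn.Module):
    def __init__(self, fan_in: int, fan_out: int, rank: int = 4):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, fan_in) * 0.01)
        # Real LoRA initializes B to zero; non-zero here so the merge is visible.
        self.lora_B = nn.Parameter(torch.randn(fan_out, rank) * 0.01)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # The parametrized layer sees W + B @ A wherever it would use W.
        return weight + self.lora_B @ self.lora_A


layer = nn.Linear(16, 8)
original_weight = layer.weight.detach().clone()

lora = ToyLoRA(fan_in=16, fan_out=8)
parametrize.register_parametrization(layer, "weight", lora)
expected = (original_weight + lora.lora_B @ lora.lora_A).detach()

# This is the call merge_lora_into_layer makes for every parametrized attribute:
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

assert not hasattr(layer, "parametrizations")   # parametrization hook is gone
assert torch.allclose(layer.weight, expected)   # the LoRA delta is baked into the weight

Because the adapter parameters are deleted in the process, the merge is one-way; this is why the _is_merged flag added in this commit blocks both a second merge and selecting a new task afterwards.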