jupyterjazz committed
Commit: e0ea168
Parent: f9b3adb

refactor: lora


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

configuration_xlm_roberta.py CHANGED
@@ -22,7 +22,10 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
-        num_loras=1,
+        lora_adaptations=None,
+        lora_rank=4,
+        lora_dropout_p=0.0,
+        lora_alpha=1,
         load_trained_adapters=False,
         use_flash_attn=True,
         torch_dtype=None,
@@ -47,8 +50,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
-        self.num_loras = num_loras
         self.load_trained_adapters = load_trained_adapters
+        self.lora_adaptations = lora_adaptations
+        self.lora_rank = lora_rank
+        self.lora_dropout_p = lora_dropout_p
+        self.lora_alpha = lora_alpha
         self.use_flash_attn = use_flash_attn
         self.emb_pooler = emb_pooler
         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
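
Instead of a bare `num_loras` count, the config now names each adaptation and carries the LoRA hyperparameters. A minimal sketch of how the new fields might be populated; the import path and task names are illustrative assumptions, not part of this commit:

from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumed import path

# Hypothetical adaptation names; the config only expects a list of strings (or None).
config = XLMRobertaFlashConfig(
    lora_adaptations=["retrieval.query", "retrieval.passage"],  # one LoRA per task
    lora_rank=4,          # rank of the low-rank update
    lora_dropout_p=0.0,   # dropout on the LoRA branch
    lora_alpha=1,         # scaling numerator: scaling = lora_alpha / lora_rank
)
assert len(config.lora_adaptations) == 2  # the model derives its adapter count from this
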
modeling_lora.py CHANGED
@@ -9,14 +9,18 @@ from torch import nn
 from torch.nn import Parameter
 from transformers import PretrainedConfig
 
-from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig
+from .modeling_xlm_roberta import (
+    XLMRobertaFlashConfig,
+    XLMRobertaModel,
+    XLMRobertaPreTrainedModel,
+)
 
 
 def initialized_weights(
-    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
+    shape: Tuple[int], num_adaptations: int, init: str = "kaiming"
 ) -> torch.Tensor:
     weight_data = []
-    for _ in range(num_adaptions):
+    for _ in range(num_adaptations):
         new_adaption = torch.zeros(shape)
         if init == "kaiming":
             nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
@@ -45,15 +49,16 @@ class LoRAParametrization(nn.Module):
     WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
     SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     """
+
     def __init__(
         self,
         fan_in: int,
         fan_out: int,
         layer_type: str = "linear",
-        num_adaptions: int = 1,
+        num_adaptations: int = 1,
         rank: int = 4,
-        lora_dropout_p: float = 0.0,
-        lora_alpha: float = 1,
+        dropout_p: float = 0.0,
+        alpha: float = 1,
     ):
         super().__init__()
         # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
@@ -63,25 +68,23 @@ class LoRAParametrization(nn.Module):
 
         if layer_type == "linear":
             self.lora_A = nn.Parameter(
-                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+                initialized_weights((rank, fan_in), num_adaptations, init="kaiming")
             )
-            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
+            self.lora_B = nn.Parameter(torch.zeros((num_adaptations, fan_out, rank)))
         elif layer_type == "embedding":
-            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
+            self.lora_A = nn.Parameter(torch.zeros((num_adaptations, fan_in, rank)))
             self.lora_B = nn.Parameter(
                 initialized_weights(
-                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                    (rank, fan_out), num_adaptations=num_adaptations, init="normal"
                 )
             )
         else:
             raise NotImplementedError
 
-        self.lora_alpha, self.rank = lora_alpha, rank
-        self.scaling = lora_alpha / rank
-        self.lora_dropout = (
-            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
-        )
-        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
+        self.lora_alpha, self.rank = alpha, rank
+        self.scaling = alpha / rank
+        self.lora_dropout = nn.Dropout(p=dropout_p) if dropout_p > 0 else lambda x: x
+        self.dropout_fn = self._dropout if dropout_p > 0 else lambda x: x
         self.register_buffer(
             "lora_dropout_mask",
             torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
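
For a linear layer, the parametrization above stacks one `(rank, fan_in)` A matrix and one `(fan_out, rank)` B matrix per adaptation, with B zero-initialized so training starts from the pretrained behaviour. The effective weight of the selected adaptation follows the standard LoRA formula; a stand-alone sketch of that arithmetic (an illustration of the idea, not the class itself):

import torch

num_adaptations, rank, fan_in, fan_out, alpha = 2, 4, 16, 16, 1
W = torch.randn(fan_out, fan_in)                      # frozen pretrained weight
lora_A = torch.randn(num_adaptations, rank, fan_in)   # kaiming-initialized in the module
lora_B = torch.zeros(num_adaptations, fan_out, rank)  # zeros: no change at step 0
scaling = alpha / rank

task = 0  # index of the selected adaptation
W_adapted = W + scaling * (lora_B[task] @ lora_A[task])  # shape (fan_out, fan_in)
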
@@ -128,42 +131,52 @@ class LoRAParametrization(nn.Module):
     def from_linear(
         cls,
         layer: nn.Module,
-        num_adaptions: int = 1,
-        rank: int = 4,
-        lora_dropout_p: float = 0.0,
-        lora_alpha: int = 1,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         assert isinstance(layer, nn.Linear)
         fan_out, fan_in = layer.weight.shape
         return cls(
             fan_in,
             fan_out,
-            num_adaptions=num_adaptions,
+            num_adaptations=num_adaptations,
             layer_type="linear",
             rank=rank,
-            lora_dropout_p=lora_dropout_p,
-            lora_alpha=lora_alpha,
+            dropout_p=dropout_p,
+            alpha=alpha,
         )
 
     @classmethod
     def from_embedding(
-        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+        cls,
+        layer: nn.Module,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         assert isinstance(layer, nn.Embedding)
         fan_in, fan_out = layer.weight.shape
         return cls(
             fan_in,
             fan_out,
-            num_adaptions=num_adaptions,
+            num_adaptations=num_adaptations,
             layer_type="embedding",
             rank=rank,
-            lora_dropout_p=lora_dropout_p,
-            lora_alpha=lora_alpha,
+            dropout_p=dropout_p,
+            alpha=alpha,
         )
 
     @classmethod
     def add_to_layer(
-        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+        cls,
+        layer: nn.Module,
+        num_adaptations: int,
+        rank: int,
+        dropout_p: float,
+        alpha: float,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -171,10 +184,10 @@ class LoRAParametrization(nn.Module):
                 "weight",
                 cls.from_linear(
                     layer,
-                    num_adaptions=num_adaptions,
+                    num_adaptations=num_adaptations,
                     rank=rank,
-                    lora_dropout_p=lora_dropout_p,
-                    lora_alpha=lora_alpha,
+                    dropout_p=dropout_p,
+                    alpha=alpha,
                 ),
             )
         elif isinstance(layer, nn.Embedding):
@@ -183,10 +196,10 @@ class LoRAParametrization(nn.Module):
                 "weight",
                 cls.from_embedding(
                     layer,
-                    num_adaptions=num_adaptions,
+                    num_adaptations=num_adaptations,
                     rank=rank,
-                    lora_dropout_p=lora_dropout_p,
-                    lora_alpha=lora_alpha,
+                    dropout_p=dropout_p,
+                    alpha=alpha,
                 ),
             )
 
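
`add_to_layer` attaches the adapters through `torch.nn.utils.parametrize`, which recomputes a module's `weight` on access while keeping the original tensor frozen underneath. A minimal, self-contained sketch of that mechanism with a toy parametrization (not `LoRAParametrization` itself):

import torch
from torch import nn
from torch.nn.utils import parametrize


class AddDelta(nn.Module):
    """Toy parametrization: expose the weight as (original + delta)."""

    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, weight):
        return weight + self.delta


layer = nn.Linear(8, 8)
parametrize.register_parametrization(layer, "weight", AddDelta(layer.weight.shape))
print(layer.weight.shape)                            # computed on the fly
print(layer.parametrizations.weight.original.shape)  # the untouched pretrained tensor
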
@@ -195,15 +208,14 @@ class LoRAParametrization(nn.Module):
         if isinstance(layer, LoRAParametrization):
             layer.current_task = task_idx
 
-    @staticmethod
-    def merge_lora_into_layer(layer: nn.Module):
-        if hasattr(layer, "parametrizations"):
-            for attr_name in layer.parametrizations.keys():
-                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
-
 
 class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
-    def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
+    def __init__(
+        self,
+        config: XLMRobertaFlashConfig,
+        roberta: Optional[XLMRobertaModel] = None,
+        add_pooling_layer=True,
+    ):
         super().__init__(config)
 
         if roberta is None:
@@ -211,10 +223,17 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         else:
             self.roberta = roberta
 
-        self._is_merged = False
-        self._num_adaptions = config.num_loras
-        self._register_lora(self._num_adaptions)
+        self._num_adaptations = len(config.lora_adaptations)
+        self._rank = config.lora_rank
+        self._dropout_p = config.lora_dropout_p
+        self._alpha = config.lora_alpha
 
+        self._register_lora(
+            num_adaptations=self._num_adaptations,
+            rank=self._rank,
+            dropout_p=self._dropout_p,
+            alpha=self._alpha,
+        )
         self.main_params_trainable = False
         self._task_idx = None
         # By default, we select the first LoRA
@@ -237,13 +256,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             if "lora" not in name:
                 param.requires_grad_(val)
 
-    def merge_lora(self):
-        """Merges currently selected LoRA into main weights."""
-        if self._is_merged:
-            raise Exception('LoRA has already been merged, cannot merge again')
-        self._is_merged = True
-        self.apply(LoRAParametrization.merge_lora_into_layer)
-
     @classmethod
     def from_pretrained(
         cls,
@@ -259,31 +271,33 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         use_safetensors: bool = None,
         **kwargs,
     ):
-        config = XLMRobertaFlashConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        config = XLMRobertaFlashConfig.from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
         if config.load_trained_adapters:
             return super().from_pretrained(
-                pretrained_model_name_or_path,
-                *model_args,
-                **kwargs
+                pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            roberta = XLMRobertaModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
             return cls(config, roberta=roberta)
 
-    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
         self.apply(
             partial(
                 LoRAParametrization.add_to_layer,
-                num_adaptions=num_adaptions,
+                num_adaptations=num_adaptations,
                 rank=rank,
-                lora_dropout_p=lora_dropout_p,
-                lora_alpha=lora_alpha,
+                dropout_p=dropout_p,
+                alpha=alpha,
             )
         )
 
     @property
     def current_task(self):
-        """ Which LoRA is currently selected
+        """Which LoRA is currently selected
         :return: Integer or None (when LoRA is disabled)
         """
         return self._task_idx
@@ -296,9 +310,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         :param task_idx: Which LoRA to use
         :return:
         """
-        if self._is_merged:
-            raise Exception('LoRA has been merged, cannot select new task')
-        assert task_idx is None or 0 <= task_idx < self._num_adaptions
+        assert task_idx is None or 0 <= task_idx < self._num_adaptations
         if self._task_idx != task_idx:
             # In this case, we need to update the LoRAs everywhere
             self._task_idx = task_idx
@@ -306,9 +318,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
         )
 
-    def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
-        if current_task is None or current_task >= 0:
-            self.current_task = current_task
+    def forward(self, *args, lora_adaptation: Union[None, int] = -1, **kwargs):
+        if lora_adaptation is None or lora_adaptation >= 0:
+            self.current_task = lora_adaptation
         return self.roberta(*args, **kwargs)
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
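
After the refactor, the adapter count and hyperparameters come from the config, and an adapter is chosen either by setting `current_task` or per call via the `lora_adaptation` keyword. A hedged sketch of driving the class directly; the import path and checkpoint id are assumptions:

from modeling_lora import XLMRobertaLoRA  # assumed import path within this repo

# Illustrative checkpoint id; from_pretrained builds the base model and registers
# one LoRA stack per entry in config.lora_adaptations.
model = XLMRobertaLoRA.from_pretrained("your-org/your-xlm-roberta-lora-checkpoint")

model.current_task = 0  # select the first adaptation for subsequent calls
# outputs = model(**encoded_input, lora_adaptation=1)     # or pick an adapter per call
# outputs = model(**encoded_input, lora_adaptation=None)  # None disables LoRA
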
modeling_xlm_roberta.py CHANGED
@@ -452,6 +452,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
+        task: Optional[str] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -481,6 +482,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case, the
                 faster dot-product (util.dot_score) instead of cosine similarity can
                 be used.
+            task(`str`, *optional*, defaults to None):
+                Specifies the task for which the encoding is intended. This
+                controls the use of specialized LoRA adapters that are tuned for specific tasks.
+                If provided, the corresponding LoRA adapter is enabled, enhancing the model's
+                performance for that task. If `None` or not provided, LoRA is disabled, and the
+                model uses its original, general-purpose weights.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -518,6 +525,22 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if device is not None:
             self.to(device)
 
+        lora_adapter_num = None
+        if self.config.lora_adaptations:
+            if task:
+                if task in self.config.lora_adaptations:
+                    lora_adapter_num = self.config.lora_adaptations.index(task)
+                else:
+                    raise ValueError(
+                        f"Unsupported task '{task}'. "
+                        f"Supported tasks are: {', '.join(self.config.lora_adaptations)}.")
+            else:
+                logger.warning(
+                    f"Task-specific embeddings are disabled. To enable, specify the `task` "
+                    f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}"
+                )
+
+
         permutation = np.argsort([-len(i) for i in sentences])
         inverse_permutation = np.argsort(permutation)
         sentences = [sentences[idx] for idx in permutation]
 
@@ -547,7 +570,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input)[0]
+            token_embs = self.forward(**encoded_input, lora_adaptation=lora_adapter_num)[0]
 
             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
@@ -1253,4 +1276,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
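
At the user level the adapter is now selected by task name: `encode` maps `task` to an index into `config.lora_adaptations`, raises on unknown names, and leaves LoRA disabled (with a warning) when no task is given. A usage sketch, assuming a checkpoint that ships this code and defines `lora_adaptations`; the repo id and task name are illustrative:

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "your-org/your-xlm-roberta-lora-checkpoint", trust_remote_code=True
)

# The task name must be one of model.config.lora_adaptations, otherwise encode() raises.
query_embs = model.encode(["How does LoRA task switching work?"], task="retrieval.query")

# Without `task`, LoRA stays disabled and a warning is logged.
plain_embs = model.encode(["General-purpose embedding"])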