jupyterjazz committed
Commit f9b3adb · 1 parent: 0bb73e5

support lora (#1)


- feat: support lora (5ed05aac24feb06378f0e19b6ae1ad4b26fe613d)
- Update modeling_lora.py (c380b5a69a5f279bad937ce1d5de87248bf0adf2)
- feat: initialize models with or without adapters (79c3c9397232303f34ba232f60e5cb6856aaf3f0)
- chore: change num lora def value (f960115170389c74fe144f688be72b8821b8e35e)
- small change (6060bad367d8fb677124ffbb8ee3ee2d7849e352)
- feat: merge stuff (841b70fc561e8292098ed7e890a75ff352ab987b)

Files changed (2)
  1. configuration_xlm_roberta.py +5 -1
  2. modeling_lora.py +325 -0
configuration_xlm_roberta.py CHANGED
@@ -22,6 +22,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
+        num_loras=1,
+        load_trained_adapters=False,
         use_flash_attn=True,
         torch_dtype=None,
         emb_pooler=None,
@@ -29,6 +31,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
@@ -44,10 +47,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
+        self.num_loras = num_loras
+        self.load_trained_adapters = load_trained_adapters
         self.use_flash_attn = use_flash_attn
         self.emb_pooler = emb_pooler
         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
             self.torch_dtype = getattr(torch, torch_dtype)
         else:
             self.torch_dtype = torch_dtype
-

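The two new options control how many LoRA adapters the model allocates (num_loras) and whether already-trained adapter weights should be loaded from the checkpoint (load_trained_adapters). A minimal sketch of setting them, assuming the module is importable from a repository checkout and that the remaining constructor arguments keep their defaults; the values below are illustrative, not the commit's defaults:

# illustrative sketch only; not part of the commit
from configuration_xlm_roberta import XLMRobertaFlashConfig

config = XLMRobertaFlashConfig(
    num_loras=2,                  # allocate two adapters per LoRA-wrapped layer
    load_trained_adapters=False,  # start from freshly initialized adapters
)
print(config.num_loras, config.load_trained_adapters)
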
modeling_lora.py ADDED
@@ -0,0 +1,325 @@
import math
import os
from functools import partial
from typing import Iterator, Optional, Tuple, Union

import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.nn import Parameter
from transformers import PretrainedConfig

from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig


def initialized_weights(
    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
) -> torch.Tensor:
    weight_data = []
    for _ in range(num_adaptions):
        new_adaption = torch.zeros(shape)
        if init == "kaiming":
            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
        elif init == "normal":
            nn.init.normal_(new_adaption)
        else:
            raise NotImplementedError
        weight_data.append(new_adaption)
    return torch.stack(weight_data, dim=0)


class LoRAParametrization(nn.Module):
    """
    This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
    The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
    Permission is hereby granted, free of charge, to any person obtaining a copy of this software
    and associated documentation files (the "Software"), to deal in the Software without restriction,
    including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
    subject to the following conditions:
    The above copyright notice and this permission notice shall be included in all copies or substantial
    portions of the Software.
    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
    LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    """

    def __init__(
        self,
        fan_in: int,
        fan_out: int,
        layer_type: str = "linear",
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: float = 1,
    ):
        super().__init__()
        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
        fan_in_fan_out = layer_type == "embedding"
        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)

        if layer_type == "linear":
            self.lora_A = nn.Parameter(
                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
            )
            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
        elif layer_type == "embedding":
            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
            self.lora_B = nn.Parameter(
                initialized_weights(
                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
                )
            )
        else:
            raise NotImplementedError

        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
        )
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
        self.register_buffer(
            "lora_dropout_mask",
            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
            persistent=False,
        )
        self.forward_fn = lambda x: x
        self.current_task = None

    def _dropout(self, A):
        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        assert self.current_task is not None
        return (
            X
            + torch.matmul(
                *self.swap(
                    (
                        self.lora_B[self.current_task],
                        self.dropout_fn(self.lora_A[self.current_task]),
                    )
                )
            ).view(X.shape)
            * self.scaling
        )

    def forward(self, X):
        return self.forward_fn(X)

    @property
    def current_task(self):
        return self._current_task

    @current_task.setter
    def current_task(self, task: Union[None, int]):
        self._current_task = task
        if task is None:
            self.forward_fn = lambda x: x
        else:
            self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(
        cls,
        layer: nn.Module,
        num_adaptions: int = 1,
        rank: int = 4,
        lora_dropout_p: float = 0.0,
        lora_alpha: int = 1,
    ):
        assert isinstance(layer, nn.Linear)
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="linear",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def from_embedding(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        assert isinstance(layer, nn.Embedding)
        fan_in, fan_out = layer.weight.shape
        return cls(
            fan_in,
            fan_out,
            num_adaptions=num_adaptions,
            layer_type="embedding",
            rank=rank,
            lora_dropout_p=lora_dropout_p,
            lora_alpha=lora_alpha,
        )

    @classmethod
    def add_to_layer(
        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
    ):
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_linear(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )
        elif isinstance(layer, nn.Embedding):
            parametrize.register_parametrization(
                layer,
                "weight",
                cls.from_embedding(
                    layer,
                    num_adaptions=num_adaptions,
                    rank=rank,
                    lora_dropout_p=lora_dropout_p,
                    lora_alpha=lora_alpha,
                ),
            )

    @staticmethod
    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
        if isinstance(layer, LoRAParametrization):
            layer.current_task = task_idx

    @staticmethod
    def merge_lora_into_layer(layer: nn.Module):
        if hasattr(layer, "parametrizations"):
            for attr_name in layer.parametrizations.keys():
                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)


class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
    def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
        super().__init__(config)

        if roberta is None:
            self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
        else:
            self.roberta = roberta

        self._is_merged = False
        self._num_adaptions = config.num_loras
        self._register_lora(self._num_adaptions)

        self.main_params_trainable = False
        self._task_idx = None
        # By default, we select the first LoRA
        self.current_task = 0

    @property
    def main_params_trainable(self):
        return self._main_params_trainable

    @main_params_trainable.setter
    def main_params_trainable(self, val: bool):
        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
        This method sets the `requires_grad_` attribute of the main weights
        and controls which parameters are returned in `self.parameters()`.
        :param val: Whether or not to make the parameters trainable.
        :return: None
        """
        self._main_params_trainable = val
        for name, param in super().named_parameters():
            if "lora" not in name:
                param.requires_grad_(val)

    def merge_lora(self):
        """Merges the currently selected LoRA into the main weights."""
        if self._is_merged:
            raise Exception('LoRA has already been merged, cannot merge again')
        self._is_merged = True
        self.apply(LoRAParametrization.merge_lora_into_layer)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        config = XLMRobertaFlashConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        if config.load_trained_adapters:
            return super().from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                **kwargs
            )
        else:
            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
            return cls(config, roberta=roberta)

    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        self.apply(
            partial(
                LoRAParametrization.add_to_layer,
                num_adaptions=num_adaptions,
                rank=rank,
                lora_dropout_p=lora_dropout_p,
                lora_alpha=lora_alpha,
            )
        )

    @property
    def current_task(self):
        """Which LoRA is currently selected.
        :return: Integer or None (when LoRA is disabled)
        """
        return self._task_idx

    @current_task.setter
    def current_task(self, task_idx: Union[None, int]):
        """Set the LoRA that is to be used.
        The LoRA is specified by `task_idx`, which may be an integer >= 0,
        indexing the available LoRAs. If it is None, no LoRA is used.
        :param task_idx: Which LoRA to use
        :return:
        """
        if self._is_merged:
            raise Exception('LoRA has been merged, cannot select new task')
        assert task_idx is None or 0 <= task_idx < self._num_adaptions
        if self._task_idx != task_idx:
            # In this case, we need to update the LoRAs everywhere
            self._task_idx = task_idx
            self.apply(
                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
            )

    def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
        if current_task is None or current_task >= 0:
            self.current_task = current_task
        return self.roberta(*args, **kwargs)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        for name, param in super().named_parameters(
            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
        ):
            if "lora" in name or self.main_params_trainable:
                yield name, param
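
Taken together, the module above wraps every nn.Linear and nn.Embedding weight of the backbone with a LoRAParametrization, keeps only the LoRA tensors trainable by default, and lets callers switch between num_loras adapters or merge the active one into the base weights. A rough usage sketch, assuming the module is importable from a repository checkout; the checkpoint path is a placeholder:

# illustrative sketch only; not part of the commit
from modeling_lora import XLMRobertaLoRA

model = XLMRobertaLoRA.from_pretrained("path/to/xlm-roberta-flash-checkpoint")

# by default only LoRA parameters are exposed for training, and adapter 0 is active
assert all("lora" in name for name, _ in model.named_parameters())

# re-select an adapter (indices above 0 are valid only if the config sets num_loras > 1),
# or assign None to bypass LoRA entirely
model.current_task = 0

# fold the active adapter into the main weights; after merging,
# selecting a different task raises an exception
model.merge_lora()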