jupyterjazz committed on
Commit
6cc0f51
1 Parent(s): 4d09ca8

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (4)
  1. embedding.py +2 -1
  2. mha.py +4 -3
  3. modeling_lora.py +44 -5
  4. modeling_xlm_roberta.py +1 -1
embedding.py CHANGED
@@ -47,7 +47,8 @@ class XLMRobertaEmbeddings(nn.Module):
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        embeddings = self.word_embeddings(input_ids)
+        print('input shape', input_ids.shape)
+        embeddings = self.word_embeddings(input_ids, task='sts')
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
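For context on the unchanged last line of the hunk: in XLM-RoBERTa-style code, create_position_ids_from_input_ids derives position ids by counting non-padding tokens, so padding positions keep padding_idx. A minimal sketch of that behaviour, mirroring the helper as it appears in Hugging Face transformers (the copy in this repo may differ in details such as a past-length offset):

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Non-padding tokens get consecutive positions starting at padding_idx + 1;
    # padding tokens keep padding_idx, matching the reserved position-embedding row.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[0, 2054, 318, 1, 1]])  # assume padding_idx = 1
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])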
mha.py CHANGED
@@ -340,8 +340,8 @@ class CrossAttention(nn.Module):
 class LinearResidual(nn.Linear):
     """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return super().forward(input), input
+    def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
+        return super().forward(input, task=task), input
 
 
 def _update_kv_cache(kv, inference_params, layer_idx):
@@ -450,6 +450,7 @@ class MHA(nn.Module):
 
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
+        print('is this true', fused_bias_fc)
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
@@ -646,7 +647,7 @@ class MHA(nn.Module):
         if not self.return_residual:
             qkv = self.Wqkv(x)
         else:
-            qkv, x = self.Wqkv(x)
+            qkv, x = self.Wqkv(x, task='sts')
         if self.dwconv:
             qkv = rearrange(
                 self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
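The last hunk's return_residual path relies on Wqkv returning both the projection and its untouched input. A minimal runnable sketch of that contract with assumed sizes; note that the committed LinearResidual also forwards task=task into super().forward(), which a stock nn.Linear.forward does not accept, so the sketch below only accepts the extra keyword and ignores it:

import torch
import torch.nn as nn

class LinearResidual(nn.Linear):
    """Return the projection together with the input so callers can reuse x."""

    def forward(self, input: torch.Tensor, task=None):
        # `task` is accepted only so task-aware callers can pass it through.
        return super().forward(input), input

wqkv = LinearResidual(8, 3 * 8)        # assumed embed_dim = 8, so qkv dim is 24
x = torch.randn(2, 5, 8)               # (batch, seqlen, embed_dim)
qkv, residual = wqkv(x, task='sts')
print(qkv.shape, residual.shape)       # torch.Size([2, 5, 24]) torch.Size([2, 5, 8])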
modeling_lora.py CHANGED
@@ -98,15 +98,15 @@ class LoRAParametrization(nn.Module):
         # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
         return A * self.lora_dropout(self.lora_dropout_mask)
 
-    def lora_forward(self, X):
-        assert self.current_task is not None
+    def lora_forward(self, X, current_task=None):
+        print('lora input shape', X.shape)
         return (
             X
             + torch.matmul(
                 *self.swap(
                     (
-                        self.lora_B[self.current_task],
-                        self.dropout_fn(self.lora_A[self.current_task]),
+                        self.lora_B[current_task],
+                        self.dropout_fn(self.lora_A[current_task]),
                     )
                 )
             ).view(X.shape)
@@ -114,7 +114,10 @@ class LoRAParametrization(nn.Module):
         )
 
     def forward(self, X):
-        return self.forward_fn(X)
+        print('forward input shape', X.shape, X)
+        out = self.forward_fn(X)
+        print(out.shape)
+        return out
 
     @property
     def current_task(self):
@@ -178,6 +181,7 @@ class LoRAParametrization(nn.Module):
         rank: int,
         dropout_p: float,
         alpha: float,
+        adaptation_map: dict,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
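The two branches that follow attach LoRAParametrization to layers through torch's parametrization utility. For readers unfamiliar with that mechanism, a hypothetical minimal example (AddOne and layer are illustrative names, not from this repo): once registered, every access to layer.weight runs the parametrization's forward() on the stored original tensor, which is what lets LoRAParametrization.forward intercept the weight.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddOne(nn.Module):
    # Toy parametrization: the "effective" weight is the stored weight plus one.
    def forward(self, W):
        return W + 1.0

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", AddOne())

raw = layer.parametrizations.weight.original    # the untouched stored tensor
print(torch.allclose(layer.weight, raw + 1.0))  # True: each access runs AddOne.forward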
 
@@ -191,6 +195,16 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
+            original_forward = layer.forward
+
+            def new_forward(self, input, task):
+                print('an aq mitxari aba')
+                output = original_forward(input, task=task)
+                weight = self.parametrizations.weight(self.weight, task)
+                return nn.functional.linear(input, weight, self.bias)
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
+
         elif isinstance(layer, nn.Embedding):
             parametrize.register_parametrization(
                 layer,
@@ -203,6 +217,23 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
+            original_forward = layer.forward
+
+            def new_forward(self, input, task):
+                print('input here', input, input.shape)
+                print('func', original_forward)
+                # original_forward['parametrizations'] = None
+                # print('funcc', original_forward.__dict__)
+                output = original_forward(input)
+                print(output.shape, 'output shape')
+                task_idx = adaptation_map[task] if task else None
+                if task_idx:
+                    output = self.parametrizations.weight[0].lora_forward(output, current_task=task_idx)
+                print('thats it')
+                return output
+
+            layer.forward = new_forward.__get__(layer, layer.__class__)
+
 
     @staticmethod
     def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
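Both branches above swap in a per-instance forward with new_forward.__get__(layer, layer.__class__). A hypothetical standalone illustration of that binding trick (Greeter and new_hello are made-up names); only the patched instance changes, other instances of the class keep the original method:

class Greeter:
    def hello(self):
        return "hello"

def new_hello(self):
    return "patched hello"

g = Greeter()
# __get__ binds the plain function to this one instance, shadowing the class method.
g.hello = new_hello.__get__(g, Greeter)
print(g.hello())          # patched hello
print(Greeter().hello())  # hello -- other instances are untouched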
 
@@ -247,6 +278,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self._task_idx = None
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None
+        for name, param in super().named_parameters():
+            if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_A':
+                print('A0', param[0])
+                print('A1', param[1])
+            if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_B':
+                print('B0', param[0])
+                print('B1', param[1])
 
     @property
     def main_params_trainable(self):
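The prints above peek at the per-task lora_A / lora_B stacks of one layer. Together with adaptation_map (wired through in the final hunk below) and lora_forward, the effective per-task tensor is the base tensor plus the selected task's low-rank product. A hedged sketch of that composition with made-up sizes (the mapping 'sts' -> 1 is illustrative, and the scaling/dropout factors handled elsewhere in the class are omitted):

import torch

num_tasks, rank, fan_in, fan_out = 3, 4, 8, 16
adaptation_map = {'retrieval': 0, 'sts': 1, 'classification': 2}  # hypothetical mapping

W = torch.randn(fan_out, fan_in)                 # base weight of some layer
lora_A = torch.randn(num_tasks, rank, fan_in)    # per-task down-projections
lora_B = torch.zeros(num_tasks, fan_out, rank)   # zero-init: no change until trained

task_idx = adaptation_map['sts']
W_eff = W + lora_B[task_idx] @ lora_A[task_idx]  # what lora_forward adds, up to scaling
print(W_eff.shape)                               # torch.Size([16, 8])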
 
@@ -300,6 +338,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 rank=rank,
                 dropout_p=dropout_p,
                 alpha=alpha,
+                adaptation_map=self._adaptation_map,
             )
         )
 
modeling_xlm_roberta.py CHANGED
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
     def gradient_checkpointing(self, value):
         self._grad_checkpointing = value
 
-    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task=None):
         """If subset_mask is not None, we only want output for the subset of the sequence.
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
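This hunk only widens XLMRobertaEncoder.forward with a task=None keyword; how it is consumed is not shown in this commit. A hypothetical, self-contained sketch of the usual pattern (TinyEncoder, TaskAwareBlock and their sizes are made up): the keyword is simply threaded through to every block so task-aware sublayers such as the patched Wqkv can pick their adapter.

import torch
import torch.nn as nn

class TaskAwareBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, task=None):
        # A real block would route `task` into its attention/MLP linears.
        return self.proj(hidden_states)

class TinyEncoder(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.layers = nn.ModuleList([TaskAwareBlock(dim) for _ in range(depth)])

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states, task=task)
        return hidden_states

out = TinyEncoder(dim=8, depth=2)(torch.randn(2, 5, 8), task='sts')
print(out.shape)  # torch.Size([2, 5, 8])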