Jackmin108 committed
Commit: 7c7eafb
Parent(s): 7815d41

feat: evaluation/encode

Signed-off-by: Meow <ongjackm@gmail.com>
- embedding.py +2 -2
- mha.py +3 -3
- mlp.py +2 -2
- modeling_lora.py +11 -20
- modeling_xlm_roberta.py +4 -4
embedding.py
CHANGED
@@ -55,7 +55,7 @@ class XLMRobertaEmbeddings(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_input_ids = input_ids[task_indices]
-                task_embeddings = self.word_embeddings(task_input_ids,
+                task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
                 embeddings[task_indices] = task_embeddings
         else:
             embeddings = self.word_embeddings(input_ids)
@@ -73,7 +73,7 @@ class XLMRobertaEmbeddings(nn.Module):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask).tolist()
             for task_id in unique_tasks:
-                task_token_type_embeddings = self.token_type_embeddings(token_type_ids,
+                task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
         else:
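The same per-task dispatch pattern recurs in every hunk of this commit: rows of a batch are grouped by their task id in adapter_mask, each group is run through the layer with its own task_id, and the results are scattered back into a shared output tensor. Below is a minimal, self-contained sketch of that pattern; the dispatch_by_task helper and the per-task layers are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn

def dispatch_by_task(x, adapter_mask, layer_for_task):
    # Route each row of x through the layer selected by its adapter_mask entry,
    # mirroring the `for task_id in unique_tasks:` loops added in this diff.
    out = torch.empty(x.shape[0], next(iter(layer_for_task.values())).out_features, dtype=x.dtype)
    for task_id in torch.unique(adapter_mask).tolist():
        task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
        out[task_indices] = layer_for_task[task_id](x[task_indices])  # hypothetical per-task layer
    return out

# Toy usage: two stand-in linear layers play the role of task-specific adapters.
layers = {0: nn.Linear(4, 3), 1: nn.Linear(4, 3)}
x = torch.randn(5, 4)
adapter_mask = torch.tensor([0, 1, 0, 1, 1], dtype=torch.int32)
print(dispatch_by_task(x, adapter_mask, layers).shape)  # torch.Size([5, 3])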
mha.py
CHANGED
@@ -655,9 +655,9 @@ class MHA(nn.Module):
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
                 if not self.return_residual:
-                    task_qkv = self.Wqkv(task_tensor,
+                    task_qkv = self.Wqkv(task_tensor, task_id=task_id)
                 else:
-                    task_qkv, _ = self.Wqkv(task_tensor,
+                    task_qkv, _ = self.Wqkv(task_tensor, task_id=task_id, residual=True)
                 qkv[task_indices] = task_qkv
         else:
             if not self.return_residual:
@@ -759,7 +759,7 @@ class MHA(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = inp[task_indices]
-                task_out = self.out_proj(task_tensor,
+                task_out = self.out_proj(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
         else:
             out = self.out_proj(inp)
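In the first MHA hunk, the patched Wqkv returns a pair when residual=True, and the caller unpacks only the projection; the second element is presumably the pre-projection input kept for a residual path. A hedged sketch of that behaviour follows; LinearMaybeResidual is a stand-in, not this repository's class, and the real layer also accepts task_id through the patched new_forward.

import torch
import torch.nn as nn

class LinearMaybeResidual(nn.Linear):
    # Return (output, input) when residual=True so the caller can keep the
    # pre-projection tensor for a residual connection, otherwise just the output.
    def forward(self, input, residual=False):
        out = super().forward(input)
        return (out, input) if residual else out

proj = LinearMaybeResidual(16, 48)      # stand-in for a packed QKV projection
x = torch.randn(4, 16)
qkv = proj(x)                           # plain projection
qkv, res = proj(x, residual=True)       # projection plus the original input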
mlp.py
CHANGED
@@ -56,7 +56,7 @@ class Mlp(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
-                task_y = self.fc1(task_tensor,
+                task_y = self.fc1(task_tensor, task_id=task_id)
                 y[task_indices] = task_y
         else:
             y = self.fc1(x)
@@ -71,7 +71,7 @@ class Mlp(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = y[task_indices]
-                task_out = self.fc2(task_tensor,
+                task_out = self.fc2(task_tensor, task_id=task_id)
                 out[task_indices] = task_out
         else:
             out = self.fc1(y)
modeling_lora.py
CHANGED
@@ -161,7 +161,6 @@ class LoRAParametrization(nn.Module):
         rank: int,
         dropout_p: float,
         alpha: float,
-        adaptation_map: dict,
     ):
         if isinstance(layer, nn.Linear):
             parametrize.register_parametrization(
@@ -176,14 +175,9 @@ class LoRAParametrization(nn.Module):
                 ),
             )

-            def new_forward(self, input,
-                if
-
-                else:
-                    task_idx = task_type
-
-                if task_idx is not None:
-                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            def new_forward(self, input, task_id=None, residual=False):
+                if task_id is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
                 else:
                     weights = self.weight

@@ -208,14 +202,9 @@ class LoRAParametrization(nn.Module):
                 ),
             )

-            def new_forward(self, input,
-                if
-
-                else:
-                    task_idx = task_type
-
-                if task_idx is not None:
-                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            def new_forward(self, input, task_id=None):
+                if task_id is not None:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
                 else:
                     weights = self.weight

@@ -325,7 +314,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                     rank=rank,
                     dropout_p=dropout_p,
                     alpha=alpha,
-                    adaptation_map=self._adaptation_map,
                 )
             )

@@ -348,6 +336,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     @torch.inference_mode()
     def encode(
         self,
+        sentences: Union[str, List[str]],
         *args,
         task_type: Optional[str] = None,
         **kwargs,
@@ -366,5 +355,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
             )
-
-
+        task_id = self._adaptation_map[task_type]
+        num_examples = 1 if isinstance(sentences, str) else len(sentences)
+        adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32)
+        return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
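The rewritten new_forward resolves the adapter directly from an integer task_id, which is why adaptation_map no longer needs to be captured here: the string-to-id lookup now happens once, in encode. Below is a hedged sketch of what a per-task lora_forward merge can look like; the stacked lora_A/lora_B layout and the scaling term are assumptions, not this repository's exact parametrization.

import torch
import torch.nn as nn

class MultiTaskLoRA(nn.Module):
    def __init__(self, in_features, out_features, num_tasks, rank, alpha):
        super().__init__()
        # One low-rank adapter per task, stacked along the first dimension.
        self.lora_A = nn.Parameter(torch.zeros(num_tasks, rank, in_features))
        self.lora_B = nn.Parameter(torch.randn(num_tasks, out_features, rank) * 0.01)
        self.scaling = alpha / rank

    def lora_forward(self, weight, current_task):
        # Merge the adapter selected by current_task into the frozen base weight,
        # in the spirit of `lora_forward(self.weight, current_task=task_id)` above.
        delta = self.lora_B[current_task] @ self.lora_A[current_task]
        return weight + self.scaling * delta

base = nn.Linear(8, 8)
lora = MultiTaskLoRA(8, 8, num_tasks=3, rank=2, alpha=4.0)
merged = lora.lora_forward(base.weight, current_task=1)   # shape (8, 8)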
modeling_xlm_roberta.py
CHANGED
@@ -321,7 +321,7 @@ class XLMRobertaPooler(nn.Module):
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_first_token_tensor = first_token_tensor[task_indices]
-                task_pooled_output = self.dense(task_first_token_tensor,
+                task_pooled_output = self.dense(task_first_token_tensor, task_id=task_id)
                 pooled_output[task_indices] = task_pooled_output
         else:
             pooled_output = self.dense(first_token_tensor)
@@ -464,7 +464,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = False,
         truncate_dim: Optional[int] = None,
-
+        adapter_mask: Optional[torch.Tensor] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -549,14 +549,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             )
         else:
             range_iter = range(0, len(sentences), batch_size)
-
+        lora_arguments = {'adapter_mask': adapter_mask} if adapter_mask is not None else {}
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
-            token_embs = self.forward(**encoded_input, **
+            token_embs = self.forward(**encoded_input, **lora_arguments)[0]

             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()
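End to end, the new encode path maps a task name to an integer id, broadcasts it into a per-sentence adapter_mask, and hands that mask to XLMRobertaModel.encode, which forwards it to the model only when it is set. A hedged usage sketch; the checkpoint name and task name are placeholders, and the valid task names come from config.lora_adaptations of the loaded model.

from transformers import AutoModel

model = AutoModel.from_pretrained('your-org/your-xlm-roberta-lora-checkpoint', trust_remote_code=True)

# With task_type, encode() builds adapter_mask = [task_id] * len(sentences) and LoRA is applied.
emb = model.encode(['A sample sentence.'], task_type='retrieval.query')  # placeholder task name

# Without task_type, no adapter_mask is passed and the base weights are used.
emb_plain = model.encode(['A sample sentence.'])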