Charles Lin committed
Commit • e56055d
Parent(s): 717a51e
Add algorithms from efk codebase
- algs/enn.py +114 -0
- algs/ft.py +121 -0
- algs/ke.py +312 -0
- algs/lu.py +90 -0
- algs/mend.py +297 -0
- algs/serac.py +452 -0
- app.py +1 -0
- editable_model.py +36 -0
- hooks.py +28 -0
- losses.py +181 -0
- metrics.py +135 -0
- models.py +196 -0
- nn.py +362 -0
- requirements.txt +6 -0
- utils.py +441 -0
algs/enn.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import higher

from editable_model import EditableModel
from utils import _logits


def fomaml_callback(all_grads):
    return [g.detach() if g is not None else None for g in all_grads]


class ENN(EditableModel):
    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x

    def outer_parameters(self, grouped=False):
        extra_params = [self.edit_lrs]
        if self.config.no_grad_layers is None:
            model_params = list(self.model.parameters())
        else:
            model_params = []
            for m in self.model.modules():
                if isinstance(m, nn.ModuleList):
                    model_params.extend(list(m[self.config.no_grad_layers:].parameters()))

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr)
            ]
        else:
            return model_params + extra_params

    def get_state_dict(self):
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        opt = torch.optim.SGD([{"params": p, "lr": None}
                               for (n, p) in self.model.named_parameters() if n in self.config.model.inner_params])
        with torch.enable_grad(), higher.innerloop_ctx(
            self.model,
            opt,
            override={'lr': list(self.edit_lrs)},
            copy_initial_weights=False,
            track_higher_grads=self.training,
            in_place=True
        ) as (fmodel, diffopt):
            fmodel.eval()
            for edit_step in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)

        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        model_edited.train(self.training)

        return ENN(model_edited, self.config, self.model_constructor, edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {}


def test():
    import transformers
    import types
    import copy

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.edit_lr = 0.1
    config.model = types.SimpleNamespace()  # was missing; inner_params needs a namespace to attach to
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    # Was a plain dict, but ENN reads these fields with attribute access
    config.enn = types.SimpleNamespace(n_edit_steps=2, first_order=False)

    enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(100).view(5, 20).cuda() + 1000
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    edited, _ = enn.edit(batch)  # edit() takes a batch dict and returns (model, info)

    orig_param = [p for (n, p) in enn.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    print((orig_param - edited_param).abs().max())
    edited.eval()
    print(enn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
    edited.edit_loss_fn(edited(x).logits, x)["nll"].backward()  # the loss fn returns a dict; index "nll" before backward
    import pdb; pdb.set_trace()


if __name__ == '__main__':
    with torch.autograd.set_detect_anomaly(True):
        test()
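Note (illustrative sketch, not part of this commit): ENN.edit relies on higher's differentiable inner loop. The minimal pattern below shows the mechanism it builds on: with copy_initial_weights=False, diffopt.step keeps the parameter update in the autograd graph, so an outer loss computed after the inner step backpropagates into the original parameters (and, in ENN, into the learned edit_lrs).

    import torch
    import torch.nn as nn
    import higher

    net = nn.Linear(4, 1)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    with higher.innerloop_ctx(net, opt, copy_initial_weights=False) as (fnet, diffopt):
        inner_loss = ((fnet(x) - y) ** 2).mean()
        diffopt.step(inner_loss)                    # differentiable SGD step
        outer_loss = ((fnet(x) - y) ** 2).mean()    # computed with the updated weights
        outer_loss.backward()                       # gradients reach net's original parameters

    print(next(net.parameters()).grad.shape)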
algs/ft.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import higher
from higher.patch import monkeypatch as make_functional
import time

from editable_model import EditableModel
from utils import _logits, _inner_params
from losses import kl_loc_loss


class FT(EditableModel):
    """
    Fine-tuning approach. Does not require training.
    """

    def __init__(self, model, config, model_constructor, edit_loss_fn=None):
        super().__init__(model, config, model_constructor)

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        self.locality_loss_fn = kl_loc_loss
        self.loc_ids = None
        self.loc_masks = None
        self.loc_sampler = None

    def _edit_loss(self, model, p0, p_edited, edit_batch):
        output = _logits(model(**edit_batch, params=p_edited))
        loss_dict = self.edit_loss_fn(output, edit_batch["labels"])
        l_edit, acc = loss_dict["nll"], loss_dict["acc"]
        if self.config.ft.locality.enabled:
            if self.config.ft.locality.oracle:
                loc_batch = next(self.loc_sampler)["loc"]
            else:
                raise NotImplementedError

            with torch.no_grad():
                original_base_logits = _logits(model(**loc_batch, params=p0))
            edited_base_logits = _logits(model(**loc_batch, params=p_edited))
            kl_mask = loc_batch.get("decoder_attention_mask", loc_batch["attention_mask"])
            l_loc = self.locality_loss_fn(original_base_logits, edited_base_logits, mask=kl_mask)
            loss = l_loc + self.config.ft.locality.cedit * l_edit
        else:
            l_loc = torch.tensor(float('nan'))
            loss = l_edit
        return loss, l_edit, l_loc, acc

    def accuracy(self, output, labels):
        if output.shape[-1] != 1:
            shifted_output = output.argmax(-1)[:, :-1]
            shifted_labels = labels[:, 1:]
            to_predict = (shifted_labels != -100).sum()
            correct = (shifted_output == shifted_labels).sum()
            acc = correct.float() / to_predict.float()
        else:
            acc = ((output > 0) == labels.bool()).sum().float()
        return acc

    def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p):
        return (
            f"step: {step}".ljust(14) +
            f"loss: {loss.item():.5f}".ljust(18) +
            f"l_edit: {l_edit.item():.5f}".ljust(18) +
            f"l_loc: {l_loc.item():.5f}".ljust(18) +
            f"acc: {acc.item():.2f}".ljust(14) +
            f"norm: {res_p.view(-1).norm().item():.5f}"
        )

    def edit(self, batch, condition=None, detach_history=False):
        edit_model = self.model.eval()
        p0 = list(edit_model.named_parameters())

        if not isinstance(edit_model, higher.patch._MonkeyPatchBase):
            edit_model = make_functional(self.model, track_higher_grads=False, in_place=True)

        packed_residuals = {}
        opt_params = []
        for n, p in _inner_params(edit_model.named_parameters(), self.config.model.inner_params):
            if self.config.ft.rank is not None:
                u = nn.Parameter(torch.randn(p.shape[0], self.config.ft.rank, device=p.device) * self.config.ft.init_std)
                v = nn.Parameter(torch.zeros(self.config.ft.rank, p.shape[1], device=p.device))
                res = [u, v]
            else:
                res = [nn.Parameter(torch.zeros_like(p, device=p.device))]

            packed_residuals[n] = res
            opt_params.extend(res)

        # One residual (a full delta or a u/v pair) per inner param; the original
        # assert compared opt_params, which holds two entries per param when rank is set
        assert len(packed_residuals) == len(self.config.model.inner_params)
        OptClass = getattr(torch.optim, self.config.ft.opt)
        opt = OptClass(opt_params, lr=self.config.edit_lr)

        start_time = time.time()
        for edit_step in range(self.config.ft.max_edit_steps):
            if self.config.ft.time_limit is not None and (time.time() - start_time > self.config.ft.time_limit):
                break
            residuals = {k: v[0] @ v[1] if len(v) == 2 else v[0] for k, v in packed_residuals.items()}
            edited_params = [p if n not in residuals else p.detach() + residuals[n] for n, p in p0]
            loss, l_edit, l_loc, acc = self._edit_loss(edit_model, [p for n, p in p0], edited_params, batch)

            if self.config.ft.verbose:
                residual = list(residuals.values())[-1]
                print(self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), end="\r")

            if acc == 1.0:
                break

            for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)):
                p.grad = g
            torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip)
            opt.step()
            opt.zero_grad()

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edit_model.state_dict())
            edit_model = new_model
        edit_model.train(self.training)

        return FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), {}
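Note (illustrative sketch, not part of this commit): when config.ft.rank is set, FT.edit optimizes a low-rank residual instead of a full weight delta. The idea in isolation: for a (d_out, d_in) weight W, train u and v and apply W + u @ v, so the edit has rank * (d_out + d_in) free parameters rather than d_out * d_in, and the zero-initialized v guarantees the edit starts at exactly W.

    import torch
    import torch.nn as nn

    d_out, d_in, rank = 16, 32, 2
    w = torch.randn(d_out, d_in)                      # frozen base weight
    u = nn.Parameter(torch.randn(d_out, rank) * 0.01)
    v = nn.Parameter(torch.zeros(rank, d_in))         # zero init: the residual starts at 0

    x = torch.randn(4, d_in)
    out = x @ (w + u @ v).T                           # only u and v carry gradients
    out.pow(2).mean().backward()
    assert w.grad is None and u.grad is not None and v.grad is not None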
algs/ke.py
ADDED
@@ -0,0 +1,312 @@
# Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
"""
@inproceedings{decao2020editing,
    title={Editing Factual Knowledge in Language Models},
    author={Nicola De Cao and Wilker Aziz and Ivan Titov},
    booktitle={arXiv pre-print 2104.08164},
    url={https://arxiv.org/abs/2104.08164},
    year={2021},
}
"""

import torch
import copy
import higher
from higher.patch import monkeypatch as make_functional
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
import logging

from editable_model import EditableModel
from utils import _logits, _inner_params
from models import BertClassifier
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration


LOG = logging.getLogger(__name__)


class KE(EditableModel):
    def __init__(self, model, config, model_constructor, editor=None):
        super().__init__(model, config, model_constructor)

        if editor is None:
            if isinstance(model, BertClassifier):
                embedding = model.model.embeddings.word_embeddings.weight.data
            elif isinstance(model, BartForConditionalGeneration):
                embedding = model.model.shared.weight.data
            elif isinstance(model, T5ForConditionalGeneration):
                embedding = model.shared.weight.data
            else:
                embedding = model.transformer.wte.weight.data

            editor = OneShotLearner(model, vocab_dim=model.config.vocab_size,
                                    include_set=config.model.inner_params,
                                    embedding_dim=embedding.shape[-1],
                                    embedding_init=embedding.clone().to(torch.float32),
                                    max_scale=1)
        self.editor = editor

    def outer_parameters(self, grouped=False):
        if grouped:
            return [
                dict(params=self.editor.parameters(), lr=self.config.lr)
            ]
        else:
            return list(self.editor.parameters())

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys()  # Remove model params
        for k in model_keys:
            del state_dict[f"model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        config = state_dict["model_config"]
        del state_dict["model_config"]
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def edit(self, batch, condition, detach_history=False):
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        names = set([n for n, p in self.model.named_parameters()])
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        grads = torch.autograd.grad(
            loss,
            [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.model.inner_params)]
        )

        params_dict = self.editor(
            condition["input_ids"] if condition is not None else batch["input_ids"],
            condition["attention_mask"] if condition is not None else batch["attention_mask"],
            {n: g.to(torch.float32) for (n, g) in zip(self.config.model.inner_params, grads)},
        )

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        def new_param(n, p):
            if n not in params_dict:
                return p

            if p.shape[0] == params_dict[n].shape[0]:
                return p + params_dict[n]
            else:
                return p + params_dict[n].T

        edited_model.update_params(
            [new_param(n, p) for (n, p) in edited_model.named_parameters()]
        )

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}


class ConditionedParameter(torch.nn.Module):
    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape

        if len(self.parameter_shape) == 2:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(
                        hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
                    )
                ),
            )
        elif len(self.parameter_shape) == 1:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
                ),
            )
        else:
            raise RuntimeError()

        self.max_scale = max_scale

    def forward(self, inputs, grad):
        if inputs.shape[0] > 1:
            raise RuntimeError("Can only condition on batches of size 1")

        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )

            a = conditioner_rowa.softmax(-1).T @ conditioner_cola
            b = conditioner_rowb.softmax(-1).T @ conditioner_colb

        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
        else:
            raise RuntimeError()

        if a.squeeze().shape[0] != grad.shape[0]:
            return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze().T + b.squeeze().T)
        else:
            return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze() + b.squeeze())


class LSTMConditioner(torch.nn.Module):
    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_dim,
            embedding_dim=embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        self.lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=True,
                batch_first=True,
            )
        )
        self.linear = FeedForward(
            input_dim=hidden_dim * 2,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        return self.linear(self.lstm(self.embedding(inputs), masks))


class OneShotLearner(torch.nn.Module):
    def __init__(
        self,
        model,
        vocab_dim,
        embedding_dim=768,
        hidden_dim=512,
        condition_dim=768,
        include_set={},
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()

        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_")
            for n, p in model.named_parameters()
            if n in include_set
        }

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in model.named_parameters()
                if n in include_set
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        condition = self.condition(inputs, masks)
        return {
            p: self.conditioners[self.param2conditioner_map[p]](
                condition,
                grad=grads[p] if grads else None,
            )
            for p, c in self.param2conditioner_map.items()
        }


if __name__ == '__main__':
    import transformers
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()  # was missing; inner_params needs a namespace to attach to
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]

    efk = KE(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(20).view(1, 20).cuda() + 1000
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    orig_logits = efk(x).logits
    edited, _ = efk.edit(batch, condition=None)  # edit() takes a batch dict and returns (model, info)
    post_logits = efk(x).logits

    assert torch.allclose(orig_logits, post_logits)

    orig_param = [p for (n, p) in efk.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    print((orig_param - edited_param).abs().max())
    edited.eval()
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])  # was: closing paren misplaced before ["nll"]
    edited2, _ = edited.edit(batch, condition=None)
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
    import pdb; pdb.set_trace()
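Note (illustrative sketch, not part of this commit): ConditionedParameter's key trick for a 2-D parameter of shape (m, n) is to predict one m-vector and one n-vector from the condition and take their outer product, giving a full (m, n) modulation of the gradient at O(m + n) prediction cost instead of O(m * n).

    import torch

    m, n, condition_dim = 6, 4, 10
    head = torch.nn.Linear(condition_dim, 2 * (m + n) + 1)

    condition = torch.randn(1, condition_dim)
    col_a, row_a, col_b, row_b, norm = head(condition).split([n, m, n, m, 1], dim=-1)

    a = row_a.softmax(-1).T @ col_a        # (m, 1) @ (1, n) -> rank-1 (m, n)
    b = row_b.softmax(-1).T @ col_b
    grad = torch.randn(m, n)
    shift = norm.sigmoid().squeeze() * (grad * a + b)   # gated, conditioned update
    print(shift.shape)                                  # torch.Size([6, 4])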
algs/lu.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import time

from editable_model import EditableModel
from utils import _last_encoder_state, _logits


class LU(EditableModel):
    """
    Representation lookup approach. Does not require training.
    """

    def __init__(self, model, config, model_constructor, memory=None):
        super().__init__(model, config, model_constructor)

        self.memory = memory

    def forward(self, *inputs, **kwargs):
        if "bert" in self.config.model.name.lower():
            output, encoder_states = self.model(*inputs, **kwargs, output_hidden_states=True)
        else:
            model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
            encoder_states = _last_encoder_state(model_output)
            output = _logits(model_output)

        if self.memory is not None:
            for i, encoder_state in enumerate(encoder_states):
                if "gpt2" in self.config.model.name.lower():
                    # NOTE: broken
                    memory_prefixes, memory_labels = self.memory
                    prefix_means = encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1)
                    dist_mat = (prefix_means.unsqueeze(1) - memory_prefixes.unsqueeze(0)).norm(2, dim=-1)

                    min_dists, min_idxs = dist_mat.min(-1)
                    memory_mask = (min_dists < self.config.lu.threshold)
                    onehot_logits = self.config.lu.onehot_logit * F.one_hot(memory_labels[min_idxs], output.shape[-1]).float()
                    output[i, memory_mask] = onehot_logits[memory_mask]
                elif "bart" in self.config.model.name.lower() or "t5" in self.config.model.name.lower():
                    avg_encoder_state = encoder_state.detach().mean(0)
                    memory_keys, memory_labels = self.memory
                    dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
                    closest_dist = dists.min()
                    closest_idx = dists.argmin()
                    closest_v = memory_labels[closest_idx]

                    if closest_dist < self.config.lu.threshold:
                        output[i] = torch.zeros((1, kwargs['labels'].shape[1], output.shape[2]), device=output.device)
                        for j, idx in enumerate(closest_v):
                            if j >= output.shape[1]:
                                break
                            output[i, j, idx] = self.config.lu.onehot_logit
                        if "t5" not in self.config.model.name.lower():
                            # T5 does not shift targets in the loss
                            output[i] = output[i].roll(-1, -2)
                else:
                    avg_encoder_state = encoder_state.detach().mean(0)
                    memory_keys, memory_labels = self.memory
                    dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
                    closest_dist = dists.min()
                    closest_idx = dists.argmin()
                    closest_v = memory_labels[closest_idx]

                    if closest_dist < self.config.lu.threshold:
                        output[i] = self.config.lu.onehot_logit * (2 * closest_v - 1)  # Return onehot_logit or -onehot_logit

        return output

    def edit(self, batch, condition=None):
        edit_model = self.model.eval()
        if "bert" in self.config.model.name.lower():
            _, encoder_states = self.model(**batch, output_hidden_states=True)
        else:
            encoder_states = _last_encoder_state(self.model(**batch, output_hidden_states=True))

        memory_keys = []
        memory_labels = []
        for encoder_state, label in zip(encoder_states, batch["labels"]):
            if "gpt2" in self.config.model.name.lower():
                # NOTE: broken
                avg_encoder_states = (encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1))[-10:, :]
                memory = (avg_encoder_states, label[-10:])
            else:
                avg_encoder_state = encoder_state.detach().mean(0)
                memory_keys.append(avg_encoder_state)
                memory_labels.append(label)

        memory = (torch.stack(memory_keys), torch.stack(memory_labels))
        return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
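Note (illustrative sketch, not part of this commit): the lookup rule in LU.forward reduces to a thresholded nearest-neighbor query. Pool an encoder state into a key, find the nearest cached key, and serve the cached label only if the match is within the threshold; otherwise defer to the base model.

    import torch

    memory_keys = torch.randn(5, 8)          # one pooled key per cached edit
    memory_labels = torch.arange(5)          # the answers stored with those keys
    threshold = 0.5                          # plays the role of config.lu.threshold

    query = memory_keys[3] + 0.01 * torch.randn(8)   # a pooled test-time encoder state
    dists = torch.norm(query - memory_keys, dim=-1)
    if dists.min() < threshold:
        prediction = memory_labels[dists.argmin()]   # serve the cached label
    else:
        prediction = None                            # defer to the base model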
algs/mend.py
ADDED
@@ -0,0 +1,297 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import transformers
import higher
import logging
from higher.patch import monkeypatch as make_functional
from collections import defaultdict

from editable_model import EditableModel
from hooks import hook_model
import nn as local_nn
from utils import _logits, _inner_params

LOG = logging.getLogger(__name__)


def update_counter(x, m, s, k):
    new_m = m + (x - m) / k
    new_s = s + (x - m) * (x - new_m)

    return new_m, new_s
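Note (illustrative sketch, not part of this commit): update_counter is Welford's online mean/variance update, which GradientTransform below uses to normalize incoming activations and gradients; s / (k - 1) recovers the sample variance without storing any past inputs.

    import torch

    xs = torch.randn(100)
    m, s = xs[0].clone(), torch.zeros(())
    for k, x in enumerate(xs[1:], start=2):
        new_m = m + (x - m) / k
        s = s + (x - m) * (x - new_m)
        m = new_m

    assert torch.allclose(m, xs.mean(), atol=1e-5)
    assert torch.allclose(s / (len(xs) - 1), xs.var(unbiased=True), atol=1e-5)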
class GradientTransform(nn.Module):
    def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes=None):
        super().__init__()

        self.x_dim = x_dim
        self.delta_dim = delta_dim
        self.cfg = cfg
        if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only):
            raise ValueError("cfg.combine cannot be used with one-sided MEND variants")

        self.norm_init = False
        self.register_buffer("u_mean", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_mean", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_std", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_std", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_s", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_s", torch.full((delta_dim,), float("nan")))
        self.register_buffer("k", torch.full((1,), float("nan")))

        MlpClass = getattr(local_nn, cfg.mlp_class)
        LOG.info(f"Building Gradient Transform with MLP class {MlpClass}")

        def delta_net():
            return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def x_net():
            return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def combined_net():
            return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2,
                            cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def ID():
            return lambda x, mode=None: x

        if cfg.combine:
            self.mlp = combined_net()
        elif cfg.one_sided:
            if x_dim > delta_dim:
                self.mlp1, self.mlp2 = ID(), delta_net()
            else:
                self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.x_only:
            self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.delta_only:
            self.mlp1, self.mlp2 = ID(), delta_net()
        else:
            self.mlp1, self.mlp2 = x_net(), delta_net()

    def forward(self, u, v, param_idx=None):
        u, v = u.to(torch.float32), v.to(torch.float32)

        u_ = u.view(-1, u.shape[-1])
        v_ = v.view(-1, v.shape[-1])

        nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1)  # Skip batch elements with zero grad
        u_ = u_[nz_mask]
        v_ = v_[nz_mask]

        if self.training:
            for idx in range(u_.shape[0]):
                if not self.norm_init:
                    self.u_mean = u_[idx].clone().detach()
                    self.v_mean = v_[idx].clone().detach()
                    self.u_s.zero_()
                    self.v_s.zero_()
                    self.k[:] = 1
                    self.norm_init = True
                else:
                    self.k += 1
                    self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k)
                    self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k)

            if self.k < 2:
                raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far")
            self.u_std = (self.u_s / (self.k - 1)) ** 0.5
            self.v_std = (self.v_s / (self.k - 1)) ** 0.5

        if self.cfg.norm:
            u_input = (u_ - self.u_mean) / (self.u_std + 1e-7)
            v_input = (v_ - self.v_mean) / (self.v_std + 1e-7)
        else:
            u_input = u_
            v_input = v_

        if self.cfg.combine:
            output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx)
            out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1)
            return out1, out2
        else:
            return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx)


class MEND(EditableModel):
    def get_shape(self, p):
        # We need to (annoyingly) flip the shapes since OpenAI gpt2 uses convs instead of linear
        return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])

    def __init__(self, model, config, model_constructor, gtn=None, edit_lrs=None):
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if not hasattr(self.model, "handles"):
            hook_model(self.model, self.config.model.inner_params)
            LOG.info(f"Hooked {len(self.model.handles)//2} modules")

        if config.gtn.shared:
            shape_dict = defaultdict(list)
            for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params):
                shape_dict[self.get_shape(p)].append(n)
            self.shape_dict = shape_dict

        if gtn is None:
            if not config.gtn.shared:
                self.gtn = nn.ModuleDict({
                    n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.gtn)
                    for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params)
                })
            else:
                self.gtn = nn.ModuleDict({
                    str(tuple(s)): GradientTransform(*s, config.gtn, len(shape_dict[s]))
                    for s in shape_dict.keys()
                })
        else:
            self.gtn = gtn

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys()  # Remove model params
        for k in model_keys:
            del state_dict[f"model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        config = state_dict["model_config"]
        del state_dict["model_config"]
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        if grouped:
            return [
                dict(params=list(self.gtn.parameters()), lr=self.config.lr),
                dict(params=[self.edit_lrs], lr=self.config.lr_lr)
            ]
        else:
            return list(self.gtn.parameters()) + [self.edit_lrs]

    def edit(self, batch, condition=None, detach_history=False):
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        names = set([n for n, p in self.model.named_parameters()])
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        loss.backward()

        if self.config.gtn.shared:
            param_idx = lambda n, p: self.shape_dict[self.get_shape(p)].index(n) if self.config.gtn.shared else None  # noqa: E731
            transformed_factors = {
                n: self.gtn[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p))
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }
        else:
            transformed_factors = {
                n: self.gtn[n.replace(".", "#")](p.__x__, p.__delta__)
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }

        # Should be bi,bj->ji for nn.Linear, but [annoying] GPT2 uses Conv1d instead...
        if isinstance(self.model, transformers.GPT2LMHeadModel):
            targ = "ij"
        else:
            targ = "ji"
        mean_grads = {
            n: torch.einsum(f"bi,bj->{targ}", x, delta)
            for n, (x, delta) in transformed_factors.items()
        }

        info_dict = {}
        idx = 0
        for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params):
            info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item()
            info_dict[f"grad/pseudo_mag{idx}"] = mean_grads[n].norm(2).item()
            info_dict[f"grad/true_std{idx}"] = p.grad.std().item()
            info_dict[f"grad/pseudo_std{idx}"] = mean_grads[n].std().item()
            info_dict[f"grad/diff{idx}"] = (p.grad - mean_grads[n]).norm(2).item()
            info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), mean_grads[n].reshape(-1), dim=0).item()
            idx += 1

        self.model.zero_grad()

        assert len(self.edit_lrs) == len(list(mean_grads.items()))
        updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())}

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        new_params = []
        for n, p in edited_model.named_parameters():
            if n in pset:
                if self.config.gtn.descent:
                    new_params.append(p - updates[n])
                else:
                    new_params.append(p + updates[n])
            else:
                new_params.append(p)

        edited_model.update_params(new_params)

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return MEND(edited_model, self.config, self.model_constructor, self.gtn, edit_lrs=self.edit_lrs), info_dict


if __name__ == '__main__':
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()  # was missing; inner_params needs a namespace to attach to
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.edit_lr = 0.0001

    config.gtn = types.SimpleNamespace()
    config.gtn.n_hidden = 1
    # NOTE: the original script next did `config.gtn = config.gtn.__dict__`, but the
    # code above reads these fields with attribute access, so the namespace is kept.
    # GradientTransform also reads cfg.mlp_class, cfg.init, cfg.act, cfg.rank, cfg.norm,
    # cfg.combine, cfg.one_sided, cfg.x_only and cfg.delta_only, and MEND reads
    # config.gtn.shared and config.gtn.descent; set them before running this script.

    gtn = MEND(model, config, lambda: copy.deepcopy(model)).cuda()
    # torch.save(gtn.state_dict(), "test_state.pt")
    import pdb; pdb.set_trace()
    gtn.load_state_dict(torch.load("test_state.pt"))
    x = torch.arange(20).view(1, 20).cuda() + 1000
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}
    orig_logits = _logits(gtn(x))
    edited, _ = gtn.edit(batch)  # edit() takes a batch dict and returns (model, info)
    post_logits = _logits(gtn(x))

    assert torch.allclose(orig_logits, post_logits)

    orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    # was LOG.info with multiple positional args, which logging misinterprets as format arguments
    print((orig_param - edited_param).abs().max())
    edited.eval()
    print(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
    edited2, _ = edited.edit(batch)
    print(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
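Note (illustrative sketch, not part of this commit): MEND exploits the identity that, for a linear map y = x @ W.T, the weight gradient is a sum of outer products of the layer input x and the output gradient delta. The hooks above cache exactly these two factors (p.__x__ and p.__delta__) so the gradient can be transformed in its low-rank factored form before the einsum rebuilds a pseudo-gradient.

    import torch

    x = torch.randn(3, 8)
    W = torch.randn(4, 8, requires_grad=True)
    y = x @ W.T
    loss = (y ** 2).sum()

    delta = torch.autograd.grad(loss, y, retain_graph=True)[0]  # dL/dy, shape (3, 4)
    loss.backward()

    # The batched outer product reproduces autograd's weight gradient exactly.
    assert torch.allclose(W.grad, torch.einsum("bi,bj->ji", x, delta), atol=1e-5)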
algs/serac.py
ADDED
@@ -0,0 +1,452 @@
import torch
import copy
import transformers
import logging

from utils import scr, set_dropout, _logits, add_padding, add_sep
from editable_model import EditableModel
from models import BertClassifier

LOG = logging.getLogger(__name__)


def translate_tokens(tokens, from_tok, to_tok):
    tokens = tokens.masked_fill(tokens == -100, from_tok.pad_token_id)
    text = from_tok.batch_decode(tokens, skip_special_tokens=True)
    return to_tok(text, return_tensors="pt")["input_ids"].to(tokens.device)


class SERAC(EditableModel):
    def __init__(self, model, config, model_constructor, classifier=None, classifier_tok=None,
                 replacement=None, replacement_tok=None, cache_inputs=None, cache_labels=None,
                 cache_embeds=None, scale=None):
        super().__init__(model, config, model_constructor)

        if classifier is None:
            if config.rep.cross_attend and not config.rep.cls_class.endswith("ForSequenceClassification"):
                LOG.warn(f"Switching {config.rep.cls_class} to {config.rep.cls_class}ForSequenceClassification for cross-attend")
                config.rep.cls_class += "ForSequenceClassification"
            self.classifier = getattr(transformers, config.rep.cls_class).from_pretrained(config.rep.cls_name, cache_dir=scr())
            if self.config.rep.checkpoint_grad:
                LOG.info(f"Checking for checkpointing: {hasattr(self.classifier.config, 'gradient_checkpointing')}")
                self.classifier.config.gradient_checkpointing = True
            self.classifier_tok = transformers.AutoTokenizer.from_pretrained(config.rep.cls_name, cache_dir=scr())
            if not self.config.rep.cross_attend and 'bert' in self.config.rep.cls_name:
                self.classifier.pooler = None  # we don't need the classification head
            elif not self.config.rep.cross_attend and "mpnet" not in self.config.rep.cls_name:
                if hasattr(self.classifier, "pooler"):
                    self.classifier.pooler = None  # we don't need the classification head

            set_dropout(self.classifier, config.dropout)
            if self.config.rep.lora is not None:
                self.classifier = LoraModel(self.classifier, self.config.rep.lora)  # NOTE: LoraModel is referenced but not imported anywhere in this commit
        else:
            assert isinstance(classifier, torch.nn.Module), f"Classifier is a {type(classifier)}!"
            assert isinstance(classifier_tok, transformers.PreTrainedTokenizerBase), f"Classifier tok is {type(classifier_tok)}!"
            self.classifier, self.classifier_tok = classifier, classifier_tok

        if replacement is None:
            # self.replacement_tok = getattr(transformers, config.model.tokenizer_class).from_pretrained(config.model.tokenizer_name,
            #                                                                                            cache_dir=scr())
            self.replacement_tok = transformers.AutoTokenizer.from_pretrained(config.model.small_name, cache_dir=scr())
            # if self.replacement_tok.sep_token is None:
            #     self.replacement_tok.sep_token = self.replacement_tok.eos_token
            if (False and self.config.rep.freeze_cntr):
                self.replacement = None
            else:
                if config.model.class_name == "BertClassifier":
                    self.replacement = BertClassifier(config.model.small_name)
                else:
                    self.replacement = getattr(transformers, config.model.class_name).from_pretrained(config.model.small_name, cache_dir=scr())
                if self.replacement_tok.sep_token is None and "gpt" not in self.model.name_or_path.lower():
                    add_sep(self.replacement_tok, self.replacement)
                if self.replacement_tok.pad_token is None:
                    add_padding(self.replacement_tok, self.replacement)
                set_dropout(self.replacement, config.dropout)
        else:
            assert isinstance(replacement, torch.nn.Module), f"Rep is {type(replacement)}!"  # was missing the f prefix
            assert isinstance(replacement_tok, transformers.PreTrainedTokenizerBase), f"Rep tok is {type(replacement_tok)}!"  # was missing the f prefix
            self.replacement, self.replacement_tok = replacement, replacement_tok

        if self.config.rep.cross_attend:
            self.scale = None
        else:
            if scale is None:
                self.register_buffer("scale", torch.tensor(1.0))
                # self.scale = nn.Parameter(torch.tensor(1.0))
            else:
                self.scale = scale

        if cache_inputs is None:
            self.cache_inputs = []
            self.cache_labels = []
            if config.rep.cache_embeds and not config.rep.cross_attend:
                self.cache_embeds = {}
        else:
            assert isinstance(cache_inputs, list), f"Cache inputs is {cache_inputs}"
            assert isinstance(cache_labels, list), f"Cache labels is {cache_labels}"
            self.cache_inputs = copy.deepcopy(cache_inputs)
            self.cache_labels = copy.deepcopy(cache_labels)
            if config.rep.cache_embeds and not config.rep.cross_attend:
                assert isinstance(cache_embeds, dict), f"Cache embeds is {cache_embeds}"
                self.cache_embeds = copy.deepcopy(cache_embeds)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys()  # Remove model params
        for k in model_keys:
            del state_dict[f"model.{k}"]
        if self.config.rep.freeze_cntr:
            cntr_keys = self.replacement.state_dict().keys()
            for k in cntr_keys:
                del state_dict[f"replacement.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        config = state_dict["model_config"]
        del state_dict["model_config"]
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        if (False and self.config.rep.freeze_cntr):
            rep_keys = list(state_dict.keys())
            for k in rep_keys:
                if k.startswith("replacement"):
                    del state_dict[k]
            res = super().load_state_dict(state_dict, False)
        else:
            try:
                res = super().load_state_dict(state_dict, False)
            except RuntimeError:
                LOG.info("Load failed; trying again without loading counterfactual model weights.")
                rep_keys = list(state_dict.keys())
                for k in rep_keys:
                    if k.startswith("replacement"):
                        del state_dict[k]
                res = super().load_state_dict(state_dict, False)

        # We should only have missing keys for the model, and no unexpected keys
        def ok_to_miss(k):
            return k.startswith("model.") or ((False and self.config.rep.freeze_cntr) and k.startswith("replacement."))
        missing_keys = [k for k in res.missing_keys if not ok_to_miss(k)]
        assert len(missing_keys) == 0, f"Should only have missing keys for model: {missing_keys}."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        if self.config.rep.freeze is not None:
            modlist = None
            for m in self.classifier.modules():
                if isinstance(m, torch.nn.ModuleList):
                    modlist = m
                    break
            model_params = list(modlist[-self.config.rep.freeze:].parameters())
        else:
            model_params = list(self.classifier.parameters())

        if self.config.rep.lora is not None or self.config.rep.freeze is not None:
            cls = self.classifier.base_model if self.config.rep.lora else self.classifier
            if hasattr(cls, "classifier"):
                model_params.extend(cls.classifier.parameters())
            if hasattr(cls, "pre_classifier"):
                model_params.extend(cls.pre_classifier.parameters())

        if not (False and self.config.rep.freeze_cntr):
            model_params.extend(list(self.replacement.parameters()))

        extra_params = []
        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr)
            ]
        else:
            return model_params + extra_params

    def edit(self, batch, condition=None, detach_history=False):
        def detokenize(toks, tok):
            tokens = toks.masked_fill(toks == -100, tok.pad_token_id)
            return tok.batch_decode(tokens, skip_special_tokens=True)

        inputs = detokenize(batch["input_ids"], self.replacement_tok)
        if "bert" in self.config.model.name:
            labels = ["" for _ in batch["labels"]]
        else:
            labels = detokenize(batch["labels"], self.replacement_tok)

        cache_inputs = self.cache_inputs + inputs
        cache_labels = self.cache_labels + labels

        if self.config.rep.cache_embeds and not self.config.rep.cross_attend:
            cls_inputs = self.build_cls_cache_inputs(inputs, labels)
            with torch.no_grad():
                embeds = self.compute_cls_embeddings(cls_inputs)

            cache_embeds = {inp: emb for inp, emb in zip(cls_inputs, embeds)}
            cache_embeds.update(self.cache_embeds)
        else:
            cache_embeds = None

        new_model = SERAC(self.model, self.config, self.model_constructor, self.classifier, self.classifier_tok,
                          self.replacement, self.replacement_tok, cache_inputs, cache_labels, cache_embeds, self.scale)
        new_model.train(self.training)
        return new_model, {}

    def stats(self):
        return self.last_stats

    def compute_cls_embeddings(self, text):
        inputs = self.classifier_tok(text, return_tensors="pt", padding=True).to(self.config.device)
        if 'bert' in self.config.rep.cls_name:
            embeds = self.classifier(**inputs).last_hidden_state[:, 0].unsqueeze(1)
        else:
            embeds = self.classifier(**inputs).pooler_output.unsqueeze(1)
        embeds = embeds.view(embeds.shape[0], self.config.rep.dist_heads, -1)
        if self.config.rep.bound_embeds:
            embeds = embeds.tanh()
        return embeds

    def embedding_logsim_matrix(self, cls_ctxs, test_input_text):
        if self.config.rep.cache_embeds and not self.config.rep.cross_attend and not self.training:
            ctx_embeds = torch.cat([self.cache_embeds[ctx] for ctx in cls_ctxs])
        else:
            ctx_embeds = self.compute_cls_embeddings(cls_ctxs)
        main_embeds = self.compute_cls_embeddings(test_input_text)

        if self.config.rep.cos:
            cos = (ctx_embeds[None] * main_embeds[:, None]).sum(-1) / (ctx_embeds[None].norm(2, -1) * main_embeds[:, None].norm(2, -1))
            dists = 1 - cos
        else:
            dists = (ctx_embeds[None] - main_embeds[:, None]).norm(2, -1)
            if self.config.rep.square:
                dists = dists ** 2

        dists = dists.min(-1).values  # get rid of the dists head dimension

        assert dists.min() >= 0, "Shouldn't have negative distances!"
        cls_logsims = -dists * self.scale

        return cls_logsims

    def crossattend_logsim_matrix(self, cls_ctxs, test_input_texts):
        batch = [ctx + self.classifier_tok.sep_token + test for test in test_input_texts for ctx in cls_ctxs]
        batch_toks = self.classifier_tok(batch, return_tensors="pt", padding=True).to(self.config.device)
        batch_logsims = self.classifier(**batch_toks).logits.log_softmax(-1)[:, 0]
        logsim_matrix = batch_logsims.view(len(test_input_texts), len(cls_ctxs))

        return logsim_matrix

    def build_rep_cache_contexts(self):
        sep = " "
        if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
            # The labels are included in the inputs for autoregressive models. Cut off the label for the classifier
            ctxs = [cin + sep for cin in self.cache_inputs]
        else:
            ctxs = [cin + sep + clab + sep for cin, clab in zip(self.cache_inputs, self.cache_labels)]
        return ctxs

    def build_cls_cache_inputs(self, cache_inputs=None, cache_labels=None):
        sep = self.classifier_tok.sep_token
        if cache_inputs is None:
            cache_inputs = self.cache_inputs
        if cache_labels is None:
            cache_labels = self.cache_labels

        if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
            # The labels are included in the inputs for autoregressive models. Cut off the label for the classifier
            inputs = [cin.rsplit(" ", 1)[0] + sep for cin in cache_inputs]
        else:
            inputs = [cin + sep + clab + sep for cin, clab in zip(cache_inputs, cache_labels)]
        return inputs

    def build_rep_input_tokens(self, kwargs, idxs, generation=False):
        assert len(idxs) == len(kwargs["input_ids"]), "Need one cache idx for each test input"
        cache_contexts = self.build_rep_cache_contexts()
        selected_contexts = [cache_contexts[idx.item()] for idx in idxs]
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
        rep_texts = [ctx + inp for ctx, inp in zip(selected_contexts, test_inputs)]
        rep_input_tokens = self.replacement_tok(rep_texts, return_tensors="pt", padding=True).to(self.config.device)

        rep_kwargs = {
            "input_ids": rep_input_tokens["input_ids"],
            "attention_mask": rep_input_tokens["attention_mask"],
        }

        if not generation:
            rep_kwargs["labels"] = kwargs["labels"]

        # if self.config.task in ["fc", "fnli"]:
        #     del rep_kwargs["labels"]

        if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
            # Add 'ignore' labels for the prepended cache inputs
            pre = torch.full((kwargs["labels"].shape[0], rep_kwargs["input_ids"].shape[-1] - kwargs["labels"].shape[-1]), -100,
                             device=kwargs["labels"].device)
            rep_kwargs["labels"] = torch.cat((pre, kwargs["labels"]), dim=-1)

        return rep_kwargs

    def run_classifier(self, *inputs, **kwargs):
        cache_inputs = self.build_cls_cache_inputs()
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)

        if self.config.rep.cross_attend:
            log_sim_matrix = self.crossattend_logsim_matrix(cache_inputs, test_inputs)
        else:
            log_sim_matrix = self.embedding_logsim_matrix(cache_inputs, test_inputs)

        sims = log_sim_matrix.exp()
        assert sims.max() <= 1, "Similarities shouldn't exceed 1!"

        cls_sims, cls_idxs = sims.max(-1)
        return cls_sims, cls_idxs, log_sim_matrix

    def generate(self, *args, **kwargs):
        # input_text = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
        base_generate_fn = (
            self.model.forward if type(self.model) == BertClassifier
            else lambda *args, **kwargs: self.model.generate(*args, **kwargs, max_new_tokens=20)
        )
        cntr_generate_fn = (
            self.replacement.forward if type(self.replacement) == BertClassifier
            else lambda *args, **kwargs: self.replacement.generate(*args, **kwargs, max_new_tokens=20)
        )

        # assert len(args) == 0, "Should only pass named arguments to generate()"
        if len(self.cache_inputs) > 0:
|
320 |
+
override = kwargs.get("override")
|
321 |
+
if override:
|
322 |
+
del kwargs["override"]
|
323 |
+
|
324 |
+
cls_sims, cls_idxs, _ = self.run_classifier(*args, **kwargs)
|
325 |
+
# assert cls_sims.numel() == 1
|
326 |
+
# print(f"Cache score: {cls_sims.item()} " + ("[MISS]" if cls_sims.item() < 0.5 else "[HIT]"))
|
327 |
+
use_cntr = (override == "cntr") if override is not None else (cls_sims.item() > 0.5)
|
328 |
+
if use_cntr:
|
329 |
+
rep_input = self.build_rep_input_tokens(kwargs, cls_idxs, generation=True)
|
330 |
+
kwargs["input_ids"] = rep_input["input_ids"]
|
331 |
+
kwargs["attention_mask"] = rep_input["attention_mask"]
|
332 |
+
# rep_input_text = self.replacement_tok.decode(rep_input["input_ids"][0])
|
333 |
+
# print(f"Returning counterfactual model output for '{rep_input_text}'")
|
334 |
+
if self.config.rep.freeze_cntr:
|
335 |
+
return base_generate_fn(*args, **kwargs)
|
336 |
+
else:
|
337 |
+
return cntr_generate_fn(*args, **kwargs)
|
338 |
+
|
339 |
+
# print(f"Returning base model output for '{input_text}'")
|
340 |
+
return base_generate_fn(*args, **kwargs)
|
341 |
+
|
342 |
+
def forward(self, *inputs, return_logits_only=True, eps=torch.finfo(torch.float32).eps, pos_pairs=None, **kwargs):
|
343 |
+
grad_enabled = torch.is_grad_enabled()
|
344 |
+
torch.set_grad_enabled(self.training)
|
345 |
+
|
346 |
+
# need to do soft mixing of logits if we're doing supervised training or we've specifically requested it
|
347 |
+
soft = (not self.config.rep.supervised) or self.config.rep.soft_weighting
|
348 |
+
with torch.no_grad():
|
349 |
+
if len(self.cache_inputs) == 0:
|
350 |
+
super_out = super().forward(*inputs, **kwargs).float()
|
351 |
+
torch.set_grad_enabled(grad_enabled)
|
352 |
+
return super_out
|
353 |
+
else:
|
354 |
+
base_logits = super().forward(*inputs, **kwargs).float()
|
355 |
+
if soft:
|
356 |
+
if base_logits.dim() == 3:
|
357 |
+
base_probs = base_logits.softmax(-1)
|
358 |
+
else:
|
359 |
+
base_probs = base_logits.sigmoid()
|
360 |
+
del base_logits
|
361 |
+
|
362 |
+
cls_sims, cls_idxs, cls_logits = self.run_classifier(*inputs, **kwargs)
|
363 |
+
rep_cls_inputs = self.build_rep_input_tokens(kwargs, cls_idxs)
|
364 |
+
if self.config.rep.freeze_cntr:
|
365 |
+
rep_cls_logits = _logits(super().forward(**rep_cls_inputs))
|
366 |
+
else:
|
367 |
+
rep_cls_logits = _logits(self.replacement(**rep_cls_inputs))
|
368 |
+
|
369 |
+
if pos_pairs is not None:
|
370 |
+
assert (pos_pairs[:, 0] == torch.arange(pos_pairs.shape[0], device=pos_pairs.device)).all()
|
371 |
+
gold_idxs = pos_pairs[:, 1]
|
372 |
+
# print("IDX acc:", (cls_idxs == gold_idxs).shape, (cls_idxs == gold_idxs).float().mean())
|
373 |
+
rep_gold_inputs = self.build_rep_input_tokens(kwargs, gold_idxs)
|
374 |
+
if (False and self.config.rep.freeze_cntr):
|
375 |
+
rep_gold_logits = _logits(super().forward(**rep_gold_inputs))
|
376 |
+
else:
|
377 |
+
rep_gold_logits = _logits(self.replacement(**rep_gold_inputs))
|
378 |
+
else:
|
379 |
+
rep_gold_logits = rep_cls_logits
|
380 |
+
|
381 |
+
cls_sims = cls_sims.view(-1, 1) # For (binary) classification, predictions are (B x 1)
|
382 |
+
if rep_cls_logits.dim() == 3:
|
383 |
+
cls_sims.unsqueeze_(-1) # For generation/seq2seq, predictions are (B x S x V)
|
384 |
+
|
385 |
+
stats = {
|
386 |
+
'sims/mean': cls_sims.mean().item(),
|
387 |
+
'sims/pos': (cls_sims >= 0.5).float().mean().item(),
|
388 |
+
'sims/neg': (cls_sims < 0.5).float().mean().item(),
|
389 |
+
'params/scale': self.scale.item() if self.scale is not None else 0.0,
|
390 |
+
}
|
391 |
+
|
392 |
+
if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
|
393 |
+
rep_cls_logits = rep_cls_logits[:, -kwargs["labels"].shape[-1]:, :]
|
394 |
+
|
395 |
+
if soft:
|
396 |
+
rep_weight = cls_sims
|
397 |
+
if base_probs.dim() == 3:
|
398 |
+
mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_cls_logits.softmax(-1) + eps).log()
|
399 |
+
else:
|
400 |
+
mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_cls_logits.sigmoid() + eps).log()
|
401 |
+
else:
|
402 |
+
rep_idxs = torch.where(cls_sims > 0.5)[0]
|
403 |
+
mixture_logits = base_logits
|
404 |
+
if rep_idxs.numel() > 0:
|
405 |
+
mixture_logits[rep_idxs] = rep_cls_logits[rep_idxs]
|
406 |
+
|
407 |
+
torch.set_grad_enabled(grad_enabled)
|
408 |
+
if return_logits_only:
|
409 |
+
return mixture_logits
|
410 |
+
else:
|
411 |
+
return mixture_logits, cls_logits, rep_gold_logits, stats
|
412 |
+
|
413 |
+
|
414 |
+
if __name__ == '__main__':
|
415 |
+
import types
|
416 |
+
|
417 |
+
model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
|
418 |
+
|
419 |
+
config = types.SimpleNamespace()
|
420 |
+
config.model.inner_params = [
|
421 |
+
"transformer.h.9.mlp.c_fc.weight",
|
422 |
+
"transformer.h.9.mlp.c_proj.weight",
|
423 |
+
"transformer.h.10.mlp.c_fc.weight",
|
424 |
+
"transformer.h.10.mlp.c_proj.weight",
|
425 |
+
"transformer.h.11.mlp.c_fc.weight",
|
426 |
+
"transformer.h.11.mlp.c_proj.weight",
|
427 |
+
]
|
428 |
+
config.edit_lr = 0.0001
|
429 |
+
|
430 |
+
config.gtn = types.SimpleNamespace()
|
431 |
+
config.gtn.n_hidden = 1
|
432 |
+
config.gtn = config.gtn.__dict__
|
433 |
+
|
434 |
+
gtn = SERAC(model, config, lambda: copy.deepcopy(model)).cuda()
|
435 |
+
# torch.save(gtn.state_dict(), "test_state.pt")
|
436 |
+
import pdb; pdb.set_trace()
|
437 |
+
gtn.load_state_dict(torch.load("test_state.pt"))
|
438 |
+
x = torch.arange(20).view(1, 20).cuda() + 1000
|
439 |
+
orig_logits = gtn(x)
|
440 |
+
edited = gtn.edit(x, masks=torch.ones_like(x), labels=x)
|
441 |
+
post_logits = gtn(x)
|
442 |
+
|
443 |
+
assert torch.allclose(orig_logits, post_logits)
|
444 |
+
|
445 |
+
orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
|
446 |
+
edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]
|
447 |
+
|
448 |
+
LOG.info((orig_param - edited_param).abs().max())
|
449 |
+
edited.eval()
|
450 |
+
LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
|
451 |
+
edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
|
452 |
+
LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
|
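For orientation, the routing above reduces to a small decision rule: the classifier embeds the cached edits and the test input, scaled distances become similarities in (0, 1], and the counterfactual model answers only when the best similarity clears 0.5. A minimal sketch of that rule with stand-in tensors (ignoring the dist_heads dimension and using made-up embeddings rather than the real classifier):

import torch

cache_embeds = torch.randn(3, 16)   # stand-in classifier embeddings, one per cached edit
query_embed = torch.randn(1, 16)    # stand-in embedding of the test input

scale = 1.0                          # the learned self.scale is positive
dists = (cache_embeds[None] - query_embed[:, None]).norm(2, -1)  # (1, 3) L2 distances
log_sims = -dists * scale            # as in embedding_logsim_matrix (non-cosine branch)

sims = log_sims.exp()                # distances >= 0, so sims lie in (0, 1]
cls_sims, cls_idxs = sims.max(-1)    # best-matching cached edit, as in run_classifier
use_replacement = cls_sims.item() > 0.5  # the same threshold generate() applies
print(use_replacement, cls_idxs.item())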
app.py
CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 import pandas as pd
 import time
+import algs
 
 EDIT_ALGS = [
     "MEND: Model editor networks using gradient decomposition",
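This commit only adds the algs import; how the Streamlit UI dispatches from a selected EDIT_ALGS entry to a module under algs/ is not shown here. A purely hypothetical sketch of one way that mapping could look (the SERAC entry text and the key convention are invented for illustration):

# Hypothetical sketch only: the real dispatch is not part of this commit.
EDIT_ALGS = [
    "MEND: Model editor networks using gradient decomposition",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model",
]
choice = EDIT_ALGS[0]
module_key = choice.split(":")[0].lower()  # e.g. "mend" -> algs/mend.py
print(module_key)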
editable_model.py
ADDED
@@ -0,0 +1,36 @@
import torch.nn as nn

from losses import masked_log_probs
from utils import _logits, shift_targets


class EditableModel(nn.Module):
    def __init__(self, model, config, model_constructor):
        super().__init__()

        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            return masked_log_probs(pred, targ, shift=shift_targets(self.config), **kwargs)
        self.edit_loss_fn = _edit_loss_fn
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        raise NotImplementedError

    def forward(self, *inputs, **kwargs):
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        if grouped:
            return [dict(params=self.parameters(), lr=self.config.lr)]
        else:
            return list(self.parameters())

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        pass
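Every algorithm in algs/ subclasses this interface: forward delegates to the wrapped model, and edit returns an (edited_model, stats) pair. As an illustration of the contract only (not part of the codebase), a no-op editor built on the class above might look like:

class IdentityEditor(EditableModel):
    """Illustrative only: 'edits' by returning an unmodified copy of the model."""
    def edit(self, batch, condition=None, detach_history=False):
        new_model = self.model_constructor()
        new_model.load_state_dict(self.model.state_dict())
        edited = IdentityEditor(new_model, self.config, self.model_constructor)
        return edited, {}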
hooks.py
ADDED
@@ -0,0 +1,28 @@
from utils import parent_module


def linear_backward_hook(mod, grad_in, grad_out):
    if not hasattr(mod, "weight"):
        print(f"{mod} has no weight!")
        return

    if hasattr(mod.weight, "__x__"):
        assert len(grad_out) == 1
        # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2)
        mod.weight.__delta__ = grad_out[0].detach()
    else:
        print(f"{mod} has no __x__")


def linear_forward_hook(mod, activations, output):
    assert len(activations) == 1
    mod.weight.__x__ = activations[0].detach()


def hook_model(model, pnames):
    handles = []
    for m in [parent_module(model, pname) for pname in pnames]:
        handles.append(m.register_full_backward_hook(linear_backward_hook))
        handles.append(m.register_forward_hook(linear_forward_hook))

    model.handles = handles
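The two hooks stash, for each hooked linear layer, the layer input on weight.__x__ (forward pass) and the output gradient on weight.__delta__ (backward pass), i.e. the two factors of the weight gradient. A minimal sketch exercising them on a single nn.Linear, without going through hook_model/parent_module:

import torch
import torch.nn as nn

lin = nn.Linear(4, 3)
lin.register_forward_hook(linear_forward_hook)
lin.register_full_backward_hook(linear_backward_hook)

y = lin(torch.randn(2, 4))
y.sum().backward()
print(lin.weight.__x__.shape, lin.weight.__delta__.shape)  # (2, 4) and (2, 3)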
losses.py
ADDED
@@ -0,0 +1,181 @@
import torch
import torch.nn.functional as F
from metrics import es_sentiment
from utils import gather_log_probs, mask_hf_labels, masked_mean


def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
    assert labels.max() <= 1
    assert labels.min() >= 0

    pos_losses = -log_probs[labels == 1]
    neg_probs = 1 - log_probs.exp()
    neg_probs[neg_probs == 0] += eps  # for numerical stability
    neg_losses = -neg_probs.log()[labels == 0]
    pos_loss = pos_losses.mean() if pos_losses.numel() > 0 else 0
    neg_loss = neg_losses.mean() if neg_losses.numel() > 0 else 0

    return pos_loss + neg_loss


def kl_loc_loss(pre, post, mask=None):
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)

    sequence = pre.dim() == 3
    pre_ = pre.view(-1, pre.shape[-1])
    post_ = post.view(pre_.shape)
    assert pre_.shape[0] == post_.shape[0]

    if not sequence:
        if pre_.shape[-1] == 1:  # No masking needed for binary classification
            return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
                (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
            ).mean()
    else:  # We have sequences of predictions; masking needed
        if pre_.shape[-1] > 1:
            assert mask is not None
            mask_ = mask.view(pre_.shape[0])
            kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
            return (kl * mask_).sum() / mask_.sum()

    raise NotImplementedError


def binary_log_probs(pred, targ, should_reduce=True):
    assert targ.max() <= 1
    assert targ.min() >= 0
    neg_mask = torch.ones_like(pred)
    neg_mask[targ == 0] *= -1
    pred = pred * neg_mask
    log_probs = F.logsigmoid(pred)
    acc = (log_probs.exp() > 0.5).float()
    if should_reduce:
        acc = acc.mean()
    return {
        "acc": acc,
        "log_prob": log_probs.mean(),
        "prob": log_probs.exp().mean(),
        "nll": -log_probs.mean(),
        "n_tokens": log_probs.shape[0]
    }


def multiclass_log_probs(
    pred,
    raw_targets,
    shift=True,
    eps=torch.finfo(torch.float32).eps,
    should_reduce=True,
    **kwargs,
):
    NULL_TOKEN = 0  # a placeholder used for masked target locations

    pred = pred.clone()
    mask, targ = mask_hf_labels(raw_targets)
    if shift and pred.dim() == 3:  # Dealing with sequences
        pred = pred[:, :-1]  # Remove last prediction in sequence
        targ = targ[:, 1:]  # Shift to align predictions and targets
        mask = mask[:, 1:]  # Shift the mask to stay aligned with predictions and targets

    unmasked_log_probs = gather_log_probs(pred, targ)

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = (pred_ids == targ).all(-1)  # We want to get the whole sequence right
    acc = correct.float()
    if should_reduce:
        acc = acc.mean()

    if "inner_sent" in kwargs:
        # Only use outer samples with the same sentiment as the inner sample
        same_sent_mask = torch.tensor([i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])], device=pred.device)
        good_mask = mask * same_sent_mask.unsqueeze(-1)
        bad_mask = mask * (~same_sent_mask.unsqueeze(-1))

        good_log_prob = masked_mean(unmasked_log_probs, good_mask)
        bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)

        n_tokens = good_mask.float().sum()
        avg_log_prob = good_log_prob

        if kwargs["unlikelihood"]:
            nll = -good_log_prob - bad_log_prob
        else:
            nll = -good_log_prob
    else:
        n_tokens = mask.float().sum()
        avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
        nll = -avg_log_prob

    info_dict = {
        "acc": acc,
        "log_prob": avg_log_prob,
        "prob": avg_log_prob.exp(),
        "n_tokens": n_tokens,
        "nll": nll
    }

    if "inner_sent" in kwargs:
        info_dict.update(es_sentiment(kwargs["pre_edit_logits"],
                                      kwargs["post_edit_logits"],
                                      raw_targets,
                                      same_sent_mask))

    return info_dict


def masked_log_probs(pred, targ, shift=True, **kwargs):
    pred = pred.to(torch.float32)

    if not (pred.dim() == 2 or pred.dim() == 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")

    if pred.shape[-1] == 1:
        should_reduce = True
        if "should_reduce" in kwargs:
            should_reduce = kwargs["should_reduce"]
        return binary_log_probs(pred, targ, should_reduce=should_reduce)
    else:
        return multiclass_log_probs(pred, targ, shift=shift, **kwargs)


def test_masked_log_probs():
    print()
    N = 10000
    pred = torch.randn(10, 15, N)
    targ = torch.randint(0, N, (10, 15))
    true_pred = pred.clone()
    true_pred.scatter_(2, targ.unsqueeze(-1), 5)
    true_pred = true_pred.roll(-1, 1)

    half_pred = true_pred.clone()
    mask = torch.arange(10) % 2 == 0
    half_pred[mask] = pred[mask]

    pred_ = pred.clone()
    true_pred_ = true_pred.clone()
    half_pred_ = half_pred.clone()
    targ_ = targ.clone()

    print(masked_log_probs(pred, targ, return_acc=True))
    print(masked_log_probs(true_pred, targ, return_acc=True))
    print(masked_log_probs(half_pred, targ, return_acc=True))

    assert (pred == pred_).all()
    assert (targ == targ_).all()
    assert (half_pred == half_pred_).all()
    assert (true_pred == true_pred_).all()

    import pdb; pdb.set_trace()

    pred = torch.randn(1000, 15, 1)
    targ = torch.randint(0, 2, (1000, 15))

    print(masked_log_probs(pred, targ, return_acc=True))


if __name__ == "__main__":
    torch.manual_seed(0)

    test_masked_log_probs()
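As a quick reference for masked_log_probs, a short sketch with sequence logits and Hugging Face-style labels, where -100 marks positions excluded from the loss:

import torch

logits = torch.randn(2, 6, 100)          # (B x S x V) sequence predictions
labels = torch.randint(0, 100, (2, 6))
labels[:, :2] = -100                      # e.g., masked-out prompt tokens

stats = masked_log_probs(logits, labels, shift=True)
print(stats["nll"].item(), stats["acc"].item(), stats["n_tokens"].item())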
metrics.py
ADDED
@@ -0,0 +1,135 @@
import torch
from utils import gather_log_probs, mask_hf_labels, masked_mean


def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, NULL_TOKEN=0):
    with torch.no_grad():
        mask, targ = mask_hf_labels(raw_targets)
        pos_mask = same_sent_mask.unsqueeze(-1) * mask
        neg_mask = (~same_sent_mask).unsqueeze(-1) * mask

        # Compute log likelihoods of pos/neg samples
        pre_edit_token_log_probs = gather_log_probs(pre_logits, targ)
        post_edit_token_log_probs = gather_log_probs(post_logits, targ)

        mean_pos_pre = masked_mean(pre_edit_token_log_probs, pos_mask)
        mean_pos_post = masked_mean(post_edit_token_log_probs, pos_mask)
        mean_neg_post = masked_mean(post_edit_token_log_probs, neg_mask)

        z_sent = (mean_pos_post - mean_neg_post).sigmoid()
        z_topic_raw = (mean_pos_post - mean_pos_pre).exp()
        z_topic = min(1, z_topic_raw)

        es_sent = z_sent * z_topic

        return {
            "acc_sent": es_sent,
            "z_sent": z_sent,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_post,
            "wrong_probs": mean_neg_post,
        }


# DEPRECATED
def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
    assert False, "No longer used"
    # content_score = post_edit_probs[pos_mask].prod() ** (1/pos_mask.sum()) / (pre_edit_probs[pos_mask]. + eps)
    post_pos_avg = post_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
    pre_pos_avg = pre_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
    content_score = post_pos_avg / (pre_pos_avg + eps)
    z_content = min(1., content_score)

    # compute z_sent through a weighting objective
    # normalized_probs = post_edit_probs / (post_edit_probs.sum() + eps)
    # balancing_factor = 0.5 * ((~pos_mask).float().sum() / pos_mask.float().sum() + 1)
    # z_sent_weight = balancing_factor * normalized_probs.dot(pos_mask.float())
    post_neg_avg = post_edit_probs[~pos_mask].prod() ** (1 / (~pos_mask).sum())
    neg_over_pos = post_neg_avg / (eps + post_pos_avg)
    z_sent_weight = 1 / (1 + neg_over_pos)

    # compute z_sent through a ranking objective
    batch_mask = pos_mask.view(-1, batch_size).long()
    sort_idxs = post_edit_probs.view(-1, batch_size).sort(-1, descending=True).indices
    ranked_mask = batch_mask.gather(1, sort_idxs)
    true_mask = batch_mask.sort(-1, descending=True).values
    z_sent_rank = (ranked_mask == true_mask).float().mean()

    # compute the final success scores
    weight_success = (z_content * z_sent_weight) ** 0.5
    rank_success = (z_content * z_sent_rank) ** 0.5

    correct_probs = post_edit_probs[pos_mask].mean()
    wrong_probs = post_edit_probs[~pos_mask].mean()

    return {
        "acc_weight": weight_success,
        "acc_rank": rank_success,
        "rank_score": z_sent_rank,
        "weight_score": z_sent_weight,
        "content_score": content_score,
        "post_edit_probs": post_edit_probs.sum(),
        "pre_edit_probs": pre_edit_probs.sum(),
        "correct_probs": correct_probs,
        "wrong_probs": wrong_probs
    }


# def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps):
#     pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
#     post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)

#     pre_batch = pre_probs.view(-1, batch_size)
#     post_batch = post_probs.view(-1, batch_size)
#     mask_batch = sent_mask.view(-1, batch_size)

#     stats = []
#     for pre, post, mask in zip(pre_batch, post_batch, mask_batch):
#         avg_pre = pre.prod() ** (1 / pre.numel())
#         avg_post = post.prod() ** (1 / post.numel())
#         z_avg = min(avg_pre / avg_post, avg_post / avg_pre)

#         post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum())
#         post_pos_avg = post[mask].prod() ** (1 / mask.sum())

#         pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum())
#         pre_pos_avg = pre[mask].prod() ** (1 / mask.sum())

#         post_neg_over_pos = post_neg_avg / (eps + post_pos_avg)
#         pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg)
#         z_post = 1 / (1 + post_neg_over_pos)
#         z_pre = 1 / (1 + pre_neg_over_pos)

#         z_sent = min(z_post / z_pre, z_pre / z_post)

#         stats.append((z_avg * z_sent) ** 0.5)

#     return sum(stats) / len(stats)


# For zsRE and F-NLI
def retain_rate(pre_logits, post_logits, mask=None):
    if pre_logits.shape[-1] == 1:
        pre_logits = pre_logits.squeeze(-1)
    if post_logits.shape[-1] == 1:
        post_logits = post_logits.squeeze(-1)

    assert pre_logits.shape == post_logits.shape
    assert pre_logits.shape[0] == mask.shape[0]

    if pre_logits.dim() == 1:
        # binary classification
        pre_preds = pre_logits > 0
        post_preds = post_logits > 0
        retain = (pre_preds == post_preds).float().mean()
    elif pre_logits.dim() == 3:
        # sequence modeling
        pre_preds = pre_logits.argmax(-1)
        post_preds = post_logits.argmax(-1)
        match = (pre_preds == post_preds) * mask
        retain = (match.sum(-1) == mask.sum(-1)).float().mean()
    else:
        raise NotImplementedError

    return retain.item()
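A small sketch of retain_rate in the sequence case: an example counts as retained only if the argmax prediction is unchanged at every masked-in position (values here are illustrative):

import torch

pre = torch.randn(4, 7, 50)        # (B x S x V) logits before editing
post = pre.clone()
post[0] += torch.randn(7, 50)      # perturb one example's predictions
mask = torch.ones(4, 7)            # score all positions

print(retain_rate(pre, post, mask=mask))  # most likely 0.75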
models.py
ADDED
@@ -0,0 +1,196 @@
import transformers
import torch
import torch.nn as nn
import re
import logging
from nn import FixableDropout
from utils import scr


LOG = logging.getLogger(__name__)


class CastModule(nn.Module):
    def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
        super().__init__()

        self.underlying = module
        self.in_cast = in_cast
        self.out_cast = out_cast

    def cast(self, obj, dtype):
        if dtype is None:
            return obj

        if isinstance(obj, torch.Tensor):
            return obj.to(dtype)
        else:
            return obj

    def forward(self, *args, **kwargs):
        args = tuple(self.cast(a, self.in_cast) for a in args)
        kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
        outputs = self.underlying(*args, **kwargs)
        if isinstance(outputs, torch.Tensor):
            outputs = self.cast(outputs, self.out_cast)
        elif isinstance(outputs, tuple):
            outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
        else:
            raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
        return outputs

    def extra_repr(self):
        return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"


class BertClassifier(torch.nn.Module):
    def __init__(self, model_name, hidden_dim=768):
        super().__init__()
        if model_name.startswith("bert"):
            self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
        else:
            self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr())
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    @property
    def config(self):
        return self.model.config

    def forward(self, *args, **kwargs):
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
        model_output = self.model(*args, **filtered_kwargs)
        if "pooler_output" in model_output.keys():
            pred = self.classifier(model_output.pooler_output)
        else:
            pred = self.classifier(model_output.last_hidden_state[:, 0])

        if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
            last_hidden_state = model_output.last_hidden_state
            return pred, last_hidden_state
        else:
            return pred


def replace_dropout(model):
    for m in model.modules():
        for n, c in m.named_children():
            if isinstance(c, nn.Dropout):
                setattr(m, n, FixableDropout(c.p))

    def resample(m, seed=None):
        for c in m.children():
            if hasattr(c, "resample"):
                c.resample(seed)
            else:
                resample(c, seed)

    model.resample_dropout = resample.__get__(model)


def get_model(config):
    if config.model.class_name == "BertClassifier":
        model = BertClassifier(config.model.name)
    else:
        ModelClass = getattr(transformers, config.model.class_name)
        LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
        model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())

    if config.model.pt is not None:
        LOG.info(f"Loading model initialization from {config.model.pt}")
        state_dict = torch.load(config.model.pt, map_location="cpu")

        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            LOG.info("Default load failed; stripping prefix and trying again.")
            state_dict = {re.sub("^model.", "", k): v for k, v in state_dict.items()}

            model.load_state_dict(state_dict)

        LOG.info("Loaded model initialization")

    if config.dropout is not None:
        n_reset = 0
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = config.dropout
                n_reset += 1

            if hasattr(m, "dropout"):  # Required for BART, which uses F.dropout
                if isinstance(m.dropout, float):
                    m.dropout = config.dropout
                    n_reset += 1

            if hasattr(m, "activation_dropout"):  # Required for BART, which uses F.dropout
                if isinstance(m.activation_dropout, float):
                    m.activation_dropout = config.dropout
                    n_reset += 1

        LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")

    param_names = [n for n, _ in model.named_parameters()]
    bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
    if len(bad_inner_params) != 0:
        raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")

    if config.no_grad_layers is not None:
        if config.half:
            model.bfloat16()

        def upcast(mod):
            modlist = None
            for child in mod.children():
                if isinstance(child, nn.ModuleList):
                    assert modlist is None, f"Found multiple modlists for {mod}"
                    modlist = child
            if modlist is None:
                raise RuntimeError("Couldn't find a ModuleList child")

            LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting")
            modlist[config.no_grad_layers:].to(torch.float32)
            modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
            modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)

        parents = []
        if hasattr(model, "transformer"):
            parents.append(model.transformer)
        if hasattr(model, "encoder"):
            parents.append(model.encoder)
        if hasattr(model, "decoder"):
            parents.append(model.decoder)
        if hasattr(model, "model"):
            parents.extend([model.model.encoder, model.model.decoder])

        for t in parents:
            t.no_grad_layers = config.no_grad_layers
            if config.half and config.alg != "rep":
                upcast(t)

        if config.half and config.alg != "rep":
            idxs = []
            for p in config.model.inner_params:
                for comp in p.split('.'):
                    if comp.isdigit():
                        idxs.append(int(comp))
            max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers)
            for pidx, p in enumerate(config.model.inner_params):
                comps = p.split('.')
                if max_idx in comps or min_idx in comps:
                    index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx)
                    comps.insert(index + 1, 'underlying')
                    new_p = '.'.join(comps)
                    LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
                    config.model.inner_params[pidx] = new_p

    return model


def get_tokenizer(config):
    tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name
    return getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr())


if __name__ == '__main__':
    m = BertClassifier("bert-base-uncased")
    m(torch.arange(5)[None, :])
    import pdb; pdb.set_trace()
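get_model and get_tokenizer read everything from a nested config object. A hedged sketch of the minimal fields they touch (values illustrative; this ignores the pt/dropout/no_grad_layers branches and assumes network access plus a writable scr() directory):

import types

config = types.SimpleNamespace(
    model=types.SimpleNamespace(
        class_name="GPT2LMHeadModel",
        name="gpt2",
        pt=None,
        inner_params=["transformer.h.11.mlp.c_proj.weight"],
        tokenizer_class="GPT2Tokenizer",
        tokenizer_name=None,
    ),
    dropout=None,
    no_grad_layers=None,
)

model = get_model(config)
tok = get_tokenizer(config)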
nn.py
ADDED
@@ -0,0 +1,362 @@
import torch
import torch.nn as nn

import logging
import time

from utils import factorization

LOG = logging.getLogger(__name__)


class FixableDropout(nn.Module):
    def __init__(self, p: float):
        super().__init__()

        self.p = p
        self.mask_cache = {}
        self.seed = 0

    def resample(self, seed=None):
        if seed is None:
            seed = int(time.time() * 1e6)
        self.mask_cache = {}
        self.seed = seed

    def forward(self, x):
        if self.training:
            if x.shape not in self.mask_cache:
                generator = torch.Generator(x.device).manual_seed(self.seed)
                self.mask_cache[x.shape] = torch.bernoulli(
                    torch.full_like(x, 1 - self.p), generator=generator
                ).bool()
                self.should_resample = False

            x = (self.mask_cache[x.shape] * x) / (1 - self.p)

        return x

    def extra_repr(self) -> str:
        return f"p={self.p}"


class ActMLP(nn.Module):
    def __init__(self, hidden_dim, n_hidden):
        super().__init__()

        self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")

    def forward(self, x):
        return self.mlp(x.view(-1, 1)).view(x.shape)


class LightIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building LightIDMLP {[indim] + [rank] + [indim]}")
        self.layer1 = nn.Linear(indim, rank)
        self.layer2 = nn.Linear(rank, indim)
        self.layer2.weight.data[:] = 0
        self.layer2.bias = None

    def forward(self, x):
        h = self.layer1(x).relu()
        return x + self.layer2(h)


class IDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
        n_modes: int = None
    ):
        super().__init__()
        LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
        self.layers = nn.ModuleList(
            [
                LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes)
                for idx in range(n_hidden + 1)
            ]
        )

    def forward(self, x, mode=None):
        for layer in self.layers:
            x = layer(x, mode=mode)

        return x


class LatentIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building Latent IDMLP ({init}) {[indim] * (n_hidden + 2)}")

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(indim, rank))
        for _ in range(n_hidden - 1):
            self.layers.append(nn.Linear(rank, rank))
        self.layers.append(nn.Linear(rank, outdim))

        for layer in self.layers[:-1]:
            nn.init.xavier_normal_(layer.weight.data)

        if init == "id":
            self.layers[-1].weight.data.zero_()
            self.layers[-1].bias.data.zero_()

        self.init = init

    def forward(self, x):
        out = x
        for layer in self.layers[:-1]:
            out = layer(out).relu()

        out = self.layers[-1](out)
        if self.init == "id":
            return out + x
        else:
            return out


class KLinear(nn.Module):
    def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
        super().__init__()

        self.inf = inf

        in_fact = factorization(inf)
        out_fact = factorization(outf)

        total_params = 0
        self.a, self.b = nn.ParameterList(), nn.ParameterList()
        for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
            new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
            if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
                break
            total_params += new_params

            self.a.append(nn.Parameter(torch.empty(o1, i1)))
            if symmetric:
                self.a.append(nn.Parameter(torch.empty(o2, i2)))

            self.b.append(nn.Parameter(torch.empty(o2, i2)))
            if symmetric:
                self.b.append(nn.Parameter(torch.empty(o1, i1)))

            assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)

        for factor in self.a:
            nn.init.kaiming_normal_(factor.data)
        for factor in self.b:
            if zero_init:
                factor.data.zero_()
            else:
                nn.init.kaiming_normal_(factor.data)

        print(f"Created ({symmetric}) k-layer using {total_params/(outf*inf):.3f} params, {len(self.a)} comps")
        self.bias = nn.Parameter(torch.zeros(outf))

    def forward(self, x):
        assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
        w = sum([a.kron(b) for a, b in zip(self.a, self.b)]) / (2 * len(self.a) ** 0.5)
        y = w @ x.T
        if self.bias is not None:
            y = y + self.bias
        return y


class LRLinear(nn.Module):
    def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
        super().__init__()

        mid_dim = min(rank, inf)
        if init == "id":
            self.u = nn.Parameter(torch.zeros(outf, mid_dim))
            self.v = nn.Parameter(torch.randn(mid_dim, inf))
        elif init == "xavier":
            self.u = nn.Parameter(torch.empty(outf, mid_dim))
            self.v = nn.Parameter(torch.empty(mid_dim, inf))
            nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
            nn.init.xavier_uniform_(self.v.data, gain=1.0)
        else:
            raise ValueError(f"Unrecognized initialization {init}")

        if n_modes is not None:
            self.mode_shift = nn.Embedding(n_modes, outf)
            self.mode_shift.weight.data.zero_()
            self.mode_scale = nn.Embedding(n_modes, outf)
            self.mode_scale.weight.data.fill_(1)

        self.n_modes = n_modes
        self.bias = nn.Parameter(torch.zeros(outf))
        self.inf = inf
        self.init = init

    def forward(self, x, mode=None):
        if mode is not None:
            assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
            assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"
        assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"

        pre_act = (self.u @ (self.v @ x.T)).T
        if self.bias is not None:
            pre_act += self.bias

        if mode is not None:
            if not isinstance(mode, torch.Tensor):
                mode = torch.tensor(mode).to(x.device)
            scale, shift = self.mode_scale(mode), self.mode_shift(mode)
            pre_act = pre_act * scale + shift

        # need clamp instead of relu so gradient at 0 isn't 0
        acts = pre_act.clamp(min=0)
        if self.init == "id":
            return acts + x
        else:
            return acts


class MLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = "xavier_uniform",
        act: str = "relu",
        rank: int = None,
    ):
        super().__init__()

        self.init = init

        if act == "relu":
            self.act = nn.ReLU()
        elif act == "learned":
            self.act = ActMLP(10, 1)
        else:
            raise ValueError(f"Unrecognized activation function '{act}'")

        if hidden_dim is None:
            hidden_dim = outdim * 2

        if init.startswith("id") and outdim != indim:
            LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
            outdim = indim

        if init == "id":
            old_hidden_dim = hidden_dim
            if hidden_dim < indim * 2:
                hidden_dim = indim * 2

            if hidden_dim % indim != 0:
                hidden_dim += hidden_dim % indim

            if old_hidden_dim != hidden_dim:
                LOG.info(
                    f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}"
                )

        if init == "id_alpha":
            self.alpha = nn.Parameter(torch.zeros(1, outdim))

        dims = [indim] + [hidden_dim] * n_hidden + [outdim]
        LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})")

        layers = []
        for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])):
            if rank is None:
                layers.append(nn.Linear(ind, outd))
            else:
                layers.append(LRLinear(ind, outd, rank=rank))
            if idx < n_hidden:
                layers.append(self.act)

        if rank is None:
            if init == "id":
                if n_hidden > 0:
                    layers[0].weight.data = torch.eye(indim).repeat(
                        hidden_dim // indim, 1
                    )
                    layers[0].weight.data[hidden_dim // 2:] *= -1
                    layers[-1].weight.data = torch.eye(outdim).repeat(
                        1, hidden_dim // outdim
                    )
                    layers[-1].weight.data[:, hidden_dim // 2:] *= -1
                    layers[-1].weight.data /= (hidden_dim // indim) / 2.0

            for layer in layers:
                if isinstance(layer, nn.Linear):
                    if init == "ortho":
                        nn.init.orthogonal_(layer.weight)
                    elif init == "id":
                        if layer.weight.shape[0] == layer.weight.shape[1]:
                            layer.weight.data = torch.eye(hidden_dim)
                    else:
                        gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
                        nn.init.xavier_uniform_(layer.weight, gain=gain)

                    layer.bias.data[:] = 0

        layers[-1].bias = None
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        if self.init == "id_alpha":
            return x + self.alpha * self.mlp(x)
        else:
            return self.mlp(x)


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    m0 = MLP(1000, 1000, 1500, 3)
    m1 = MLP(1000, 1000, 1500, 3, init="id")
    m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
    m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")

    x = 0.01 * torch.randn(999, 1000)

    y0 = m0(x)
    y1 = m1(x)
    y2 = m2(x)
    y3 = m3(x)

    print("y0", (y0 - x).abs().max())
    print("y1", (y1 - x).abs().max())
    print("y2", (y2 - x).abs().max())
    print("y3", (y3 - x).abs().max())

    assert not torch.allclose(y0, x)
    assert torch.allclose(y1, x)
    assert torch.allclose(y2, x)
    assert not torch.allclose(y3, x)
    import pdb; pdb.set_trace()  # fmt: skip
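One property worth noting in LRLinear: with init="id", u and the bias start at zero, so the low-rank residual contributes nothing and the layer is exactly the identity before training, the same invariant the __main__ block checks for MLP. A one-line sketch of that check:

import torch

layer = LRLinear(64, 64, rank=4, init="id")
x = torch.randn(8, 64)
assert torch.allclose(layer(x), x)  # exact identity at initialization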
requirements.txt
ADDED
@@ -0,0 +1,6 @@
allennlp
git+https://github.com/eric-mitchell/higher@master # For in-place functional models
pandas
streamlit
torch
transformers
utils.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import typing
|
3 |
+
import numpy as np
|
4 |
+
import struct
|
5 |
+
import os
|
6 |
+
import getpass
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from collections import defaultdict
|
11 |
+
import math
|
12 |
+
|
13 |
+
|
14 |
+
LOG = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
def masked_mean(values, mask):
|
17 |
+
assert mask.dtype == torch.bool
|
18 |
+
assert values.shape == mask.shape
|
19 |
+
return (values * mask.float()).sum() / mask.sum().float()
|
20 |
+
|
21 |
+
|
22 |
+
def mask_hf_labels(labels, null_token=0):
|
23 |
+
valid_mask = labels != -100
|
24 |
+
valid_labels = labels.masked_fill(~valid_mask, null_token)
|
25 |
+
return valid_mask, valid_labels
|
26 |
+
|
27 |
+
|
28 |
+
def gather_log_probs(logits, labels):
|
29 |
+
assert labels.dim() == logits.dim() - 1
|
30 |
+
assert labels.shape == logits.shape[:-1]
|
31 |
+
return logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
def off_diagonal(mat):
|
35 |
+
assert mat.dim() == 2
|
36 |
+
# assert mat.shape[0] == mat.shape[1]
|
37 |
+
|
38 |
+
mask = ~torch.eye(max(mat.shape), dtype=torch.bool)
|
39 |
+
mask = mask[:mat.shape[0], :mat.shape[1]]
|
40 |
+
off_d = mat[mask]
|
41 |
+
|
42 |
+
assert off_d.numel() == mat.shape[0] * mat.shape[1] - min(mat.shape)
|
43 |
+
|
44 |
+
return off_d
|
45 |
+
|
46 |
+
|
47 |
+
def set_dropout(model, p):
|
48 |
+
if p is not None:
|
49 |
+
n_reset = 0
|
50 |
+
for m in model.modules():
|
51 |
+
if isinstance(m, nn.Dropout):
|
52 |
+
m.p = p
|
53 |
+
n_reset += 1
|
54 |
+
|
55 |
+
if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout
|
56 |
+
if isinstance(m.dropout, float):
|
57 |
+
m.dropout = p
|
58 |
+
n_reset += 1
|
59 |
+
|
60 |
+
if hasattr(m, "activation_dropout"): # Requires for BART, which uses F.dropout
|
61 |
+
if isinstance(m.activation_dropout, float):
|
62 |
+
m.activation_dropout = p
|
63 |
+
n_reset += 1
|
64 |
+
|
65 |
+
LOG.info(f"Set {n_reset} dropout modules to p={p}")
|
66 |
+
|
67 |
+
|
68 |
+
def _inner_params(named_parameters, inner_names):
|
69 |
+
param_dict = dict(named_parameters)
|
70 |
+
return [(n, param_dict[n]) for n in inner_names]
|
71 |
+
|
72 |
+
|
73 |
+
def shift_targets(config):
|
74 |
+
return "t5" not in config.model.name.lower() and "blender" not in config.model.name.lower()
|
75 |
+
|
76 |
+
|
77 |
+
# https://stackoverflow.com/questions/32871539/integer-factorization-in-python
|
78 |
+
def factorization(n):
|
79 |
+
return [(i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0]
|
80 |
+
|
81 |
+
|
82 |
+
def scr():
|
83 |
+
if os.path.exists("/scr-ssd"):
|
84 |
+
scr_dir = "/scr-ssd/" + getpass.getuser()
|
85 |
+
else:
|
86 |
+
scr_dir = "/scr/" + getpass.getuser()
|
87 |
+
|
88 |
+
if not os.path.exists(scr_dir):
|
89 |
+
os.makedirs(scr_dir)
|
90 |
+
|
91 |
+
return scr_dir
|
92 |
+
|
93 |
+
|
94 |
+
def uuid(digits=4):
|
95 |
+
if not hasattr(uuid, "uuid_value"):
|
96 |
+
uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
|
97 |
+
|
98 |
+
return uuid.uuid_value
|
99 |
+
|
100 |
+
|
101 |
+
def formatted_timestamp(time=None):
|
102 |
+
if time is None:
|
103 |
+
time = datetime.datetime.now()
|
104 |
+
return time.strftime("%d/%m/%Y-%H:%M:%S/%f")
|
105 |
+
|
106 |
+
|
107 |
+
def time_delta_seconds(start, finish=None):
|
108 |
+
assert type(start) == str
|
109 |
+
|
110 |
+
t1 = datetime.datetime.strptime(start, "%d/%m/%Y-%H:%M:%S/%f")
|
111 |
+
if finish is not None:
|
112 |
+
assert type(finish) == str
|
113 |
+
t2 = datetime.datetime.strptime(finish, "%d/%m/%Y-%H:%M:%S/%f")
|
114 |
+
else:
|
115 |
+
t2 = datetime.datetime.now()
|
116 |
+
|
117 |
+
return (t2 - t1).total_seconds()
|
118 |
+
|
119 |
+
|
120 |
+
def dict_to(d, device):
|
121 |
+
new_dict = {}
|
122 |
+
for k, v in d.items():
|
123 |
+
if isinstance(v, torch.Tensor):
|
124 |
+
new_dict[k] = v.to(device)
|
125 |
+
elif isinstance(v, dict):
|
126 |
+
new_dict[k] = dict_to(v, device)
|
127 |
+
else:
|
128 |
+
new_dict[k] = v
|
129 |
+
|
130 |
+
return new_dict
|
131 |
+
|
132 |
+
|
133 |
+
def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=False):
|
134 |
+
if backward:
|
135 |
+
(loss / accumulate).backward()
|
136 |
+
else:
|
137 |
+
parameters = list(parameters) # Capture the generator output
|
138 |
+
grads = torch.autograd.grad(loss, parameters, allow_unused=allow_unused)
|
139 |
+
nan, inf = False, False
|
140 |
+
for g in grads:
|
141 |
+
if g is not None:
|
142 |
+
nan |= g.isnan().any().item()
|
143 |
+
inf |= g.isinf().any().item()
|
144 |
+
|
145 |
+
if not (nan or inf):
|
146 |
+
for p, g in zip(parameters, grads):
|
147 |
+
if g is None:
|
148 |
+
continue
|
149 |
+
|
150 |
+
if p.grad is None:
|
151 |
+
p.grad = g / accumulate
|
152 |
+
else:
|
153 |
+
p.grad += g / accumulate
|
154 |
+
else:
|
155 |
+
LOG.info(f"Skipping grad accumulation because inf: {inf} nan: {nan}")
|
156 |
+
|
157 |
+
|
158 |
+
def _logits(x):
|
159 |
+
return x if not hasattr(x, "logits") else x.logits
|
160 |
+
|
161 |
+
|
162 |
+
def _last_encoder_state(x):
|
163 |
+
if hasattr(x, "encoder_last_hidden_state"):
|
164 |
+
return x.encoder_last_hidden_state
|
165 |
+
else:
|
166 |
+
return x.hidden_states[-1]


def load_archive(path):
    import torch

    if not os.path.exists(path):
        # We've not passed an explicit path, but a part of the filename
        wd = '/iris/u/clin/code/efk/'
        directories = ["outputs", "multirun"]
        matches = []
        for d in directories:
            search = os.path.join(wd, d)
            for run_dir in os.listdir(search):
                if path in run_dir:
                    matches.append(os.path.join(search, run_dir))
        assert len(matches) == 1, f"Expected exactly 1 match for search {path}, got {len(matches)}; specify exact path"

        full_run_dir = matches[0]
        if "0" in os.listdir(full_run_dir):
            full_run_dir = os.path.join(full_run_dir, "0")
        models_dir = os.path.join(full_run_dir, "models")
        models = os.listdir(models_dir)
        non_bk = [m for m in models if not m.endswith(".bk")]
        assert (
            len(non_bk) == 1
        ), f"Expected a single model in {models_dir}, got {len(non_bk)}"
        path = os.path.join(models_dir, non_bk[0])

    LOG.info(f"Loading checkpoint from {path}")
    archive = torch.load(path, map_location="cpu")
    LOG.info("Load complete.")

    return archive, path


def flatten_dict(d):
    to_process = list(d.items())
    output = {}
    while len(to_process):
        k, v = to_process.pop()
        if isinstance(v, typing.MutableMapping):
            to_process.extend([(f"{k}.{k_}", v_) for (k_, v_) in v.items()])
        else:
            assert k not in output.keys(), "Somehow ended up with duplicate keys"
            output[k] = v

    return output
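

# A minimal usage sketch (illustrative, hypothetical values): nested mappings
# are flattened into dot-separated keys.
def _example_flatten_dict():
    nested = {"loss": {"nll": 1.2, "kl": 0.3}, "acc": 0.9}
    return flatten_dict(nested)  # {"loss.nll": 1.2, "loss.kl": 0.3, "acc": 0.9}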


def add_padding(tokenizer, model):
    # GPT-2-style models only (embeddings live at model.transformer.wte); the
    # new [PAD] embedding is initialized to the mean of the existing rows.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
    model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0)


def add_sep(tokenizer, model):
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # model.resize_token_embeddings(len(tokenizer))
    # model.lm_head.weight.data[-1, :] = model.lm_head.weight.data.mean(0)


class EarlyStopper:
    def __init__(self, patience: int, key: str, minimize: bool = False):
        self.best_value = 1e9 if minimize else -1e9
        self.best_iter = 0
        self.current_iter = 0
        self.key = key
        self.patience = patience
        self.minimize = minimize
        self._stop = False

    def update(self, idx, stats):
        assert self.key in stats, f"'{self.key}' not in stats dict"
        value = stats[self.key]
        new_best = value < self.best_value if self.minimize else value > self.best_value
        if new_best:
            self.best_value = value
            self.best_iter = idx

        self.current_iter = idx
        return new_best

    def should_stop(self):
        self._stop |= self.current_iter - self.best_iter >= self.patience
        return self._stop
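

# A minimal usage sketch (illustrative, hypothetical loss values): stop once
# the tracked key has not improved for `patience` consecutive evaluations.
def _example_early_stopping():
    stopper = EarlyStopper(patience=3, key="loss/val", minimize=True)
    for step, val_loss in enumerate([1.0, 0.8, 0.81, 0.82, 0.85]):
        stopper.update(step, {"loss/val": val_loss})
        if stopper.should_stop():
            break
    return stopper.best_iter  # -> 1 (best loss 0.8 was seen at step 1)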


class RunningStatAverager:
    def __init__(self, suffix="", exclude=["grad/"], compute_ppl: bool = True):
        self.underlying = None
        self.suffix = suffix
        self.exclude = exclude
        self.compute_ppl = compute_ppl

        self.reset()

    def add(self, d: dict):
        for k, v in d.items():
            if not any([k.startswith(prefix) for prefix in self.exclude]):
                if len(self.suffix):
                    self.underlying[f"{k}_{self.suffix}"].append(v)
                else:
                    self.underlying[k].append(v)

    def average(self):
        average = {}
        for k, v in self.underlying.items():
            if not k.startswith("nll/"):
                average[k] = sum(v) / len(v)
            else:
                # NLL stats are averaged per token, weighted by each batch's token count
                assert len(k.split("/")) == 2, f"Invalid key {k}"
                name = k.split("/")[1]
                token_counts = self.underlying[f"n_tokens/{name}"]
                total_nll = sum([nll * c for nll, c in zip(v, token_counts)])
                average[k] = total_nll / sum(token_counts)
                if self.compute_ppl:
                    average[f"perplexity/{name}"] = math.e ** average[k]

        return {k: v if not isinstance(v, torch.Tensor) else v.item() for k, v in average.items()}

    def reset(self):
        self.underlying = defaultdict(list)
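

# A minimal usage sketch (illustrative, hypothetical values): "nll/*" keys are
# token-weighted via the matching "n_tokens/*" key; other keys are simple means.
def _example_stat_averaging():
    averager = RunningStatAverager()
    averager.add({"nll/edit": 2.0, "n_tokens/edit": 10, "acc/edit": 1.0})
    averager.add({"nll/edit": 1.0, "n_tokens/edit": 30, "acc/edit": 0.5})
    stats = averager.average()
    # stats["nll/edit"] == 1.25  (token-weighted: (2.0*10 + 1.0*30) / 40)
    # stats["perplexity/edit"] == math.e ** 1.25
    # stats["acc/edit"] == 0.75  (simple mean)
    return stats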


class EditBatchSampler:
    def __init__(
        self,
        n,
        memorize_mode=False,
        loc_disjoint=True,
        seed=0,
        hard_neg=False,
        hard_neg_prob=1.0,
        loc_distr_matrix=None,
        loc_idx_matrix=None,
        keep_probs=None,
        mutex=None
    ):
        self.memorize_mode = memorize_mode
        self.n = n
        self.loc_disjoint = loc_disjoint
        self.rng = np.random.default_rng(seed)
        self.hard_neg = hard_neg
        self.hard_neg_prob = hard_neg_prob
        self.loc_probs = loc_distr_matrix
        self.loc_idxs = loc_idx_matrix
        self.keep_probs = np.array(keep_probs)[:self.n] if keep_probs is not None else None
        self.mutex = mutex[:self.n] if mutex is not None else None
        self._init()

    def _init(self):
        idxs = np.arange(self.n)
        if self.keep_probs is not None:
            sample = self.rng.binomial(1, self.keep_probs).astype(bool)
            idxs = idxs[sample]

        self.perm = self.rng.permutation(idxs)
        self.edit_position = 0

    def get_edit_idxs(self, batch_size):
        if self.mutex is None:
            idxs = set([int(idx) for idx in self.perm[self.edit_position: self.edit_position + batch_size]])
            self.edit_position += batch_size
        else:
            mutexes = []
            idxs = []

            def notin(x, mutexes):
                for m in mutexes:
                    if x in m or m in x:
                        return False
                return True

            while len(idxs) < batch_size:
                new_idx = self.perm[self.edit_position]
                if notin(self.mutex[new_idx], mutexes):
                    mutexes.append(self.mutex[new_idx])
                    idxs.append(int(new_idx))
                self.edit_position += 1
                if self.edit_position == self.perm.shape[0]:
                    return None

            idxs = set(idxs)

        return idxs

    def sample(self, batch_size, return_hard_flag=False):
        if self.memorize_mode:
            return list(range(batch_size)), list(range(batch_size, batch_size * 2))

        if self.edit_position + batch_size >= self.perm.shape[0]:
            self._init()  # Re-start if we end with a partially-sized batch

        edit_idxs = self.get_edit_idxs(batch_size)
        if edit_idxs is None:
            self._init()
            edit_idxs = self.get_edit_idxs(batch_size)
            if edit_idxs is None:
                raise RuntimeError(f"No valid batches of size {batch_size} exist!")

        if self.hard_neg:
            assert self.loc_probs is not None, "hard_neg is on, but don't have distance matrix!"

        def get_loc_idxs():
            if self.hard_neg and self.rng.uniform() < self.hard_neg_prob:
                return [int(self.rng.choice(self.loc_idxs[idx], p=self.loc_probs[idx])) for idx in edit_idxs], True
            else:
                # Use deterministic implementation in case edit batches are large
                non_edit_idxs = list(set(range(self.n)) - set(edit_idxs))
                return [int(idx) for idx in self.rng.choice(non_edit_idxs, batch_size)], False

        loc_idxs, hard = get_loc_idxs()
        if self.loc_disjoint:
            steps = 0
            while len(edit_idxs.intersection(set(loc_idxs))) > 0:
                loc_idxs, hard = get_loc_idxs()
                steps += 1
                if steps > 100:
                    raise RuntimeError("Can't find disjoint loc_idxs and edit_idxs!")

        if return_hard_flag:
            return list(edit_idxs), loc_idxs, hard
        else:
            return list(edit_idxs), loc_idxs
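

# A minimal usage sketch (illustrative): draw disjoint edit and locality
# batches from a hypothetical pool of 100 examples.
def _example_edit_batch_sampler():
    sampler = EditBatchSampler(100, loc_disjoint=True, seed=0)
    edit_idxs, loc_idxs = sampler.sample(batch_size=8)
    assert len(set(edit_idxs) & set(loc_idxs)) == 0
    return edit_idxs, loc_idxs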


def parent_module(model, pname):
    comps = pname.split('.')
    parent = model
    for comp in comps[:-1]:
        if hasattr(parent, comp):
            parent = getattr(parent, comp)
        elif comp.isdigit():
            parent = parent[int(comp)]
        else:
            raise RuntimeError(f"Couldn't find child module {comp}")
    assert hasattr(parent, comps[-1])
    return parent
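

# A minimal usage sketch (illustrative; the parameter name is hypothetical and
# depends on the architecture): fetch the module that owns a named parameter
# so it can be hooked or replaced.
def _example_parent_module(model):
    owner = parent_module(model, "transformer.h.0.mlp.c_fc.weight")
    return owner  # the module whose `.weight` is the named parameter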


def build_distr_matrix(edit_qs, config, loc_qs=None, slice_size=1000):
    n = len(edit_qs)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    num_neighbors = config.data.hard_neg_neighbors
    num_exclude = config.data.hard_neg_exclude
    temp = config.data.hard_neg_temp

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import pytorch_cos_sim
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=scr()).to(device)

    ind_matrix = torch.zeros((n, num_neighbors - num_exclude), dtype=torch.long)
    distr_matrix = torch.full((n, num_neighbors - num_exclude), float('nan'))
    edit_encodings = torch.FloatTensor(embedding_model.encode(edit_qs, batch_size=256)).to(device)

    # If loc_qs is None then build the similarity matrix between edit_qs and itself
    loc_encodings = edit_encodings if loc_qs is None else embedding_model.encode(loc_qs, batch_size=256)
    if isinstance(loc_encodings, np.ndarray):
        loc_encodings = torch.FloatTensor(loc_encodings).to(device)

    # Compute cosine similarities in slices to bound memory; keep the top
    # `num_neighbors` per row, dropping the `num_exclude` most similar
    # (typically the query itself and near-duplicates).
    for idx in range(0, n, slice_size):
        end_idx = min(idx + slice_size, n)
        slice_encodings = edit_encodings[idx:end_idx]
        sim_rows = pytorch_cos_sim(slice_encodings, loc_encodings)
        indices = sim_rows.topk(num_neighbors, -1).indices[:, num_exclude:]
        ind_matrix[idx:end_idx] = indices.cpu()
        distr_matrix[idx:end_idx] = sim_rows.gather(-1, indices).mul(temp).exp().cpu()

    assert not torch.isnan(distr_matrix).any()

    LOG.info(f"Built hard negative distribution matrix of size {distr_matrix.shape}")
    # Normalize each row into a temperature-scaled softmax over neighbor similarities
    distr_matrix = distr_matrix.numpy()
    distr_matrix = distr_matrix / distr_matrix.sum(-1, keepdims=True)
    return distr_matrix, ind_matrix.numpy()
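

# A minimal usage sketch (illustrative): the two matrices returned above plug
# directly into EditBatchSampler's hard-negative arguments, so locality
# examples can be drawn from the nearest neighbors of each edit example.
def _example_hard_negative_sampling(edit_qs, config):
    distr, idxs = build_distr_matrix(edit_qs, config)
    sampler = EditBatchSampler(
        len(edit_qs),
        hard_neg=True,
        hard_neg_prob=0.5,
        loc_distr_matrix=distr,
        loc_idx_matrix=idxs,
    )
    return sampler.sample(batch_size=8, return_hard_flag=True)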