model-editing / algs /mend.py
Charles Lin
Add algorithms from efk codebase
e56055d
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import transformers
import higher
import logging
from higher.patch import monkeypatch as make_functional
from collections import defaultdict
from editable_model import EditableModel
from hooks import hook_model
import nn as local_nn
from utils import _logits, _inner_params
LOG = logging.getLogger(__name__)
def update_counter(x, m, s, k):
    """Fold one new observation into running (Welford) statistics.

    Args:
        x: the new observation.
        m: running mean before ``x`` is included.
        s: running sum of squared deviations from the mean before ``x``.
        k: number of observations seen *including* ``x``.

    Returns:
        ``(new_mean, new_sum_sq)``. Dividing ``new_sum_sq`` by ``k - 1`` gives
        the sample variance.
    """
    deviation = x - m
    mean_next = m + deviation / k
    # Uses the deviation from the old mean times the deviation from the new one.
    return mean_next, s + deviation * (x - mean_next)
class GradientTransform(nn.Module):
    """Learned transform for one pair of gradient factors ``(u, v)``.

    ``u`` has ``x_dim`` features and ``v`` has ``delta_dim`` features (last
    axis). ``forward`` optionally standardizes both with running statistics
    gathered in training mode, then maps them through MLP(s) built from
    ``getattr(local_nn, cfg.mlp_class)``:

    * ``cfg.combine``: one MLP over the concatenation ``[u, v]``;
    * ``cfg.one_sided``: only the smaller of the two factors is transformed;
    * ``cfg.x_only`` / ``cfg.delta_only``: only ``u`` / only ``v`` is transformed;
    * otherwise: one MLP for ``u`` and one for ``v``.

    Args:
        x_dim: feature size of ``u``.
        delta_dim: feature size of ``v``.
        cfg: config namespace; reads ``combine``, ``one_sided``, ``x_only``,
            ``delta_only``, ``mlp_class``, ``n_hidden``, ``init``, ``act``,
            ``rank`` and (in ``forward``) ``norm``.
        n_modes: passed through to the MLP constructor. ``forward``'s
            ``param_idx`` is handed to the MLPs as ``mode``.

    Raises:
        ValueError: if ``cfg.combine`` is set together with a one-sided variant.
    """

    def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes=None):
        super().__init__()

        self.x_dim = x_dim
        self.delta_dim = delta_dim
        self.cfg = cfg
        if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only):
            raise ValueError("cfg.combine cannot be used with one-sided MEND variants")

        # Running normalization statistics. They start as NaN and are seeded by
        # the first training-mode forward pass. u_s / v_s are running sums of
        # squared deviations, and k is the number of rows seen so far.
        # NOTE(review): norm_init is a plain attribute, not a buffer, so it is
        # not saved in state_dict. After loading a checkpoint, the next
        # training-mode forward re-seeds the statistics from scratch.
        self.norm_init = False
        self.register_buffer("u_mean", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_mean", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_std", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_std", torch.full((delta_dim,), float("nan")))
        self.register_buffer("u_s", torch.full((x_dim,), float("nan")))
        self.register_buffer("v_s", torch.full((delta_dim,), float("nan")))
        self.register_buffer("k", torch.full((1,), float("nan")))

        MlpClass = getattr(local_nn, cfg.mlp_class)
        LOG.info(f"Building Gradient Transform with MLP class {MlpClass}")

        def delta_net():
            return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden,
                            init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def x_net():
            return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden,
                            init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def combined_net():
            return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2,
                            cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)

        def ID():
            # Identity stand-in with the same call signature as the MLPs.
            return lambda x, mode=None: x

        if cfg.combine:
            self.mlp = combined_net()
        elif cfg.one_sided:
            # Transform only the smaller factor and pass the other through.
            if x_dim > delta_dim:
                self.mlp1, self.mlp2 = ID(), delta_net()
            else:
                self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.x_only:
            self.mlp1, self.mlp2 = x_net(), ID()
        elif cfg.delta_only:
            self.mlp1, self.mlp2 = ID(), delta_net()
        else:
            self.mlp1, self.mlp2 = x_net(), delta_net()

    def forward(self, u, v, param_idx=None):
        """Normalize and transform the factors ``u`` and ``v``.

        Both inputs are cast to float32 and flattened to 2-D over their last
        axis. Rows where ``u`` or ``v`` is entirely zero are dropped, so the
        outputs may have fewer rows than the inputs.

        In training mode, every kept row updates the running mean and sum of
        squared deviations (via ``update_counter``), and the standard deviations
        are recomputed. Normalization itself is applied only when ``cfg.norm``
        is set.

        Args:
            u: factor with ``x_dim`` features in its last axis.
            v: factor with ``delta_dim`` features in its last axis.
            param_idx: forwarded to the MLP(s) as ``mode``.

        Returns:
            ``(u_out, v_out)``: the transformed factors for the kept rows.

        Raises:
            RuntimeError: in training mode, if fewer than two rows have been
                accumulated so far, including when none have been (``k`` is NaN).
        """
        u, v = u.to(torch.float32), v.to(torch.float32)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        u_ = u.reshape(-1, u.shape[-1])
        v_ = v.reshape(-1, v.shape[-1])

        nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1)  # Skip batch elements with zero grad
        u_ = u_[nz_mask]
        v_ = v_[nz_mask]

        if self.training:
            for idx in range(u_.shape[0]):
                if not self.norm_init:
                    # The first row seeds the statistics.
                    self.u_mean = u_[idx].clone().detach()
                    self.v_mean = v_[idx].clone().detach()
                    self.u_s.zero_()
                    self.v_s.zero_()
                    self.k[:] = 1
                    self.norm_init = True
                else:
                    self.k += 1
                    self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k)
                    self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k)

            # A NaN count means no statistics were ever collected. `nan < 2` is
            # False, so it must be checked explicitly instead of yielding NaN stds.
            if torch.isnan(self.k).any() or self.k < 2:
                raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far")
            self.u_std = (self.u_s / (self.k - 1)) ** 0.5
            self.v_std = (self.v_s / (self.k - 1)) ** 0.5

        if self.cfg.norm:
            u_input = (u_ - self.u_mean) / (self.u_std + 1e-7)
            v_input = (v_ - self.v_mean) / (self.v_std + 1e-7)
        else:
            u_input = u_
            v_input = v_

        if self.cfg.combine:
            output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx)
            out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1)
            return out1, out2
        return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx)
class MEND(EditableModel):
    """Model editor that applies MEND's learned gradient transforms.

    ``edit`` runs the wrapped model on a batch and back-propagates the edit
    loss. It then feeds the factors stored on each inner parameter
    (``p.__x__`` / ``p.__delta__``, presumably set by the hooks installed with
    ``hook_model``) through a ``GradientTransform`` and recombines them into a
    pseudo-gradient. That pseudo-gradient is scaled by a learned per-parameter
    step size (``edit_lrs``) and applied to a ``higher`` monkeypatched version
    of the model.

    Args:
        model: the model to edit.
        config: config namespace; reads ``config.model.inner_params``,
            ``config.edit_lr``, ``config.gtn.*``, ``config.lr`` and
            ``config.lr_lr``.
        model_constructor: zero-argument callable that builds a fresh model
            (used by ``edit(..., detach_history=True)``).
        gtn: existing ``nn.ModuleDict`` of transforms to reuse; built if None.
        edit_lrs: existing per-parameter step sizes to reuse; built from
            ``config.edit_lr`` if None.
    """

    def get_shape(self, p):
        """Return the ``(x_dim, delta_dim)`` pair used to size ``p``'s transform.

        For ``transformers.GPT2LMHeadModel`` the weight shape is used as-is. For
        other models (``nn.Linear`` layout) the two axes are swapped. ``edit``
        makes the same distinction when choosing its einsum output order.
        """
        # We need to (annoyingly) flip the shapes since OpenAI gpt2 uses convs instead of linear
        return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])

    def __init__(self, model, config, model_constructor, gtn=None, edit_lrs=None):
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            # One learnable step size per inner parameter.
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if not hasattr(self.model, "handles"):
            hook_model(self.model, self.config.model.inner_params)
            LOG.info(f"Hooked {len(self.model.handles)//2} modules")

        if config.gtn.shared:
            # Group inner parameters by (x_dim, delta_dim). All parameters of one
            # shape share a single GradientTransform and are told apart by their
            # index in this list, which edit() passes as the transform's mode.
            shape_dict = defaultdict(list)
            for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params):
                shape_dict[self.get_shape(p)].append(n)
            self.shape_dict = shape_dict

        if gtn is None:
            if not config.gtn.shared:
                # One transform per parameter; ModuleDict keys may not contain ".".
                self.gtn = nn.ModuleDict({
                    n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.gtn)
                    for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params)
                })
            else:
                self.gtn = nn.ModuleDict({
                    str(tuple(s)): GradientTransform(*s, config.gtn, len(shape_dict[s]))
                    for s in shape_dict.keys()
                })
        else:
            self.gtn = gtn

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return this editor's state without the wrapped model's weights.

        Only the editor's own entries (e.g. ``gtn`` and ``edit_lrs``) are kept.
        The wrapped model's ``config`` is stored under ``"model_config"`` so
        ``load_state_dict`` can compare it. ``destination`` and ``prefix`` are
        honoured as in ``nn.Module.state_dict``.
        """
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # The wrapped model's entries are stored as "<prefix>model.<key>".
        for k in self.model.state_dict(keep_vars=keep_vars):
            del state_dict[f"{prefix}model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load editor state produced by ``state_dict``.

        ``state_dict`` must contain a ``"model_config"`` entry; a mismatch with
        the current model's config is only logged. The caller's mapping is not
        modified. Loading is always non-strict because the wrapped model's
        weights are intentionally absent; ``strict`` is accepted for API
        compatibility but ignored. The result is checked instead.

        Raises:
            KeyError: if ``"model_config"`` is missing.
            AssertionError: if a non-model key is missing, or any key is unexpected.
        """
        # Shallow copy so the caller's mapping keeps its "model_config" entry.
        state_dict = copy.copy(state_dict)
        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        """Return the meta-learned parameters (transforms and step sizes).

        With ``grouped=True``, returns optimizer param groups: the transforms use
        ``config.lr`` and ``edit_lrs`` uses ``config.lr_lr``. Otherwise returns a
        flat parameter list.
        """
        if grouped:
            return [
                dict(params=list(self.gtn.parameters()), lr=self.config.lr),
                dict(params=[self.edit_lrs], lr=self.config.lr_lr),
            ]
        return list(self.gtn.parameters()) + [self.edit_lrs]

    def edit(self, batch, condition=None, detach_history=False):
        """Apply one MEND edit for ``batch``.

        Args:
            batch: keyword arguments for the wrapped model; must contain
                ``"labels"`` for the edit loss.
            condition: unused.
            detach_history: if True, copy the edited weights into a fresh model
                from ``model_constructor`` (presumably to drop the edit's
                autograd history).

        Returns:
            ``(edited, info_dict)``. ``edited`` is a new ``MEND`` that wraps the
            edited model and shares this editor's ``gtn`` and ``edit_lrs``.
            ``info_dict`` holds per-parameter diagnostics comparing the true
            gradient with the pseudo-gradient.

        Raises:
            AssertionError: if a configured inner parameter is not in the model.
        """
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        # Fail fast on a misconfigured inner-parameter name.
        names = {n for n, _ in self.model.named_parameters()}
        pset = set(self.config.model.inner_params)
        for p in pset:
            assert p in names, f"inner param {p} not in model"

        loss.backward()

        if self.config.gtn.shared:
            # Transforms are shared per weight shape. The parameter's position in
            # shape_dict selects its mode within the shared transform.
            transformed_factors = {
                n: self.gtn[str(tuple(self.get_shape(p)))](
                    p.__x__, p.__delta__, self.shape_dict[self.get_shape(p)].index(n)
                )
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }
        else:
            transformed_factors = {
                n: self.gtn[n.replace(".", "#")](p.__x__, p.__delta__)
                for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
            }

        # Should be bi,bj->ji for nn.Linear, but [annoying] GPT2 uses Conv1d instead...
        targ = "ij" if isinstance(self.model, transformers.GPT2LMHeadModel) else "ji"
        # Pseudo-gradient: sum over rows of the outer products of the transformed factors.
        mean_grads = {
            n: torch.einsum(f"bi,bj->{targ}", x, delta)
            for n, (x, delta) in transformed_factors.items()
        }

        # Diagnostics comparing the true gradient (p.grad) with the pseudo-gradient.
        info_dict = {}
        inner = _inner_params(self.model.named_parameters(), self.config.model.inner_params)
        for idx, (n, p) in enumerate(inner):
            pseudo = mean_grads[n]
            info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item()
            info_dict[f"grad/pseudo_mag{idx}"] = pseudo.norm(2).item()
            info_dict[f"grad/true_std{idx}"] = p.grad.std().item()
            info_dict[f"grad/pseudo_std{idx}"] = pseudo.std().item()
            info_dict[f"grad/diff{idx}"] = (p.grad - pseudo).norm(2).item()
            info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), pseudo.reshape(-1), dim=0).item()

        self.model.zero_grad()

        # edit_lrs is matched to mean_grads positionally (zip), so both must
        # follow the same inner-parameter order.
        assert len(self.edit_lrs) == len(mean_grads)
        updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())}

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        new_params = []
        for n, p in edited_model.named_parameters():
            if n in pset:
                if self.config.gtn.descent:
                    new_params.append(p - updates[n])
                else:
                    new_params.append(p + updates[n])
            else:
                new_params.append(p)

        edited_model.update_params(new_params)

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return MEND(edited_model, self.config, self.model_constructor, self.gtn, edit_lrs=self.edit_lrs), info_dict
if __name__ == '__main__':
    # Manual smoke test: wrap GPT-2 in MEND, round-trip the editor state, and
    # run two consecutive edits on the same sequence. Needs the project-local
    # modules imported above and downloads the "gpt2" weights.
    import types

    logging.basicConfig(level=logging.INFO)  # so LOG.info messages are shown
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()
    # NOTE(review): EditableModel may read further config.model fields (e.g. a
    # model name); extend this namespace if construction fails.
    config.model.name = "gpt2"
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.edit_lr = 0.0001
    # Example values for every cfg field GradientTransform / MEND read in this
    # file. TODO confirm against the project's config files. mlp_class must
    # name a class in the local `nn` module.
    config.gtn = types.SimpleNamespace(
        n_hidden=1,
        one_sided=False,
        x_only=False,
        delta_only=False,
        combine=True,
        shared=True,
        descent=False,
        norm=True,
        mlp_class="IDMLP",
        init="id",
        act="relu",
        rank=1920,
    )

    gtn = MEND(model, config, lambda: copy.deepcopy(model)).to(device)

    # Exercise state_dict / load_state_dict without touching the filesystem.
    gtn.load_state_dict(gtn.state_dict())

    x = (torch.arange(20).view(1, 20) + 1000).to(device)
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    orig_logits = _logits(gtn(x))
    edited, _ = gtn.edit(batch)
    post_logits = _logits(gtn(x))
    assert torch.allclose(orig_logits, post_logits)

    target = config.model.inner_params[-1]
    orig_param = [p for (n, p) in gtn.model.named_parameters() if n == target][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == target][0]
    LOG.info("max |orig - edited| for %s: %s", target, (orig_param - edited_param).abs().max().item())

    edited.eval()
    LOG.info(
        "loss: original %s, edited %s, edited nll %s",
        gtn.model(x, labels=x).loss.item(),
        edited.model(x, labels=x).loss.item(),
        edited.edit_loss_fn(_logits(edited(x)), x)["nll"].item(),
    )

    edited2, _ = edited.edit(batch)
    LOG.info(
        "loss: original %s, edited %s, edited twice %s",
        gtn.model(x, labels=x).loss.item(),
        edited.model(x, labels=x).loss.item(),
        edited2.model(x, labels=x).loss.item(),
    )