Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import copy | |
import transformers | |
import higher | |
import logging | |
from higher.patch import monkeypatch as make_functional | |
from collections import defaultdict | |
from editable_model import EditableModel | |
from hooks import hook_model | |
import nn as local_nn | |
from utils import _logits, _inner_params | |
LOG = logging.getLogger(__name__) | |
def update_counter(x, m, s, k): | |
new_m = m + (x - m) / k | |
new_s = s + (x - m) * (x - new_m) | |
return new_m, new_s | |
class GradientTransform(nn.Module): | |
def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes = None): | |
super().__init__() | |
self.x_dim = x_dim | |
self.delta_dim = delta_dim | |
self.cfg = cfg | |
if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only): | |
raise ValueError("cfg.combine cannot be used with one-sided MEND variants") | |
self.norm_init = False | |
self.register_buffer("u_mean", torch.full((x_dim,), float("nan"))) | |
self.register_buffer("v_mean", torch.full((delta_dim,), float("nan"))) | |
self.register_buffer("u_std", torch.full((x_dim,), float("nan"))) | |
self.register_buffer("v_std", torch.full((delta_dim,), float("nan"))) | |
self.register_buffer("u_s", torch.full((x_dim,), float("nan"))) | |
self.register_buffer("v_s", torch.full((delta_dim,), float("nan"))) | |
self.register_buffer("k", torch.full((1,), float("nan"))) | |
MlpClass = getattr(local_nn, cfg.mlp_class) | |
LOG.info(f"Building Gradient Transform with MLP class {MlpClass}") | |
def delta_net(): | |
return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) | |
def x_net(): | |
return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) | |
def combined_net(): | |
return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2, | |
cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) | |
def ID(): | |
return lambda x, mode=None: x | |
if cfg.combine: | |
self.mlp = combined_net() | |
elif cfg.one_sided: | |
if x_dim > delta_dim: | |
self.mlp1, self.mlp2 = ID(), delta_net() | |
else: | |
self.mlp1, self.mlp2 = x_net(), ID() | |
elif cfg.x_only: | |
self.mlp1, self.mlp2 = x_net(), ID() | |
elif cfg.delta_only: | |
self.mlp1, self.mlp2 = ID(), delta_net() | |
else: | |
self.mlp1, self.mlp2 = x_net(), delta_net() | |
def forward(self, u, v, param_idx=None): | |
u, v = u.to(torch.float32), v.to(torch.float32) | |
u_ = u.view(-1, u.shape[-1]) | |
v_ = v.view(-1, v.shape[-1]) | |
nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1) # Skip batch elements with zero grad | |
u_ = u_[nz_mask] | |
v_ = v_[nz_mask] | |
if self.training: | |
for idx in range(u_.shape[0]): | |
if not self.norm_init: | |
self.u_mean = u_[idx].clone().detach() | |
self.v_mean = v_[idx].clone().detach() | |
self.u_s.zero_() | |
self.v_s.zero_() | |
self.k[:] = 1 | |
self.norm_init = True | |
else: | |
self.k += 1 | |
self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k) | |
self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k) | |
if self.k < 2: | |
raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far") | |
self.u_std = (self.u_s / (self.k - 1)) ** 0.5 | |
self.v_std = (self.v_s / (self.k - 1)) ** 0.5 | |
if self.cfg.norm: | |
u_input = (u_ - self.u_mean) / (self.u_std + 1e-7) | |
v_input = (v_ - self.v_mean) / (self.v_std + 1e-7) | |
else: | |
u_input = u_ | |
v_input = v_ | |
if self.cfg.combine: | |
output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx) | |
out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1) | |
return out1, out2 | |
else: | |
return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx) | |
class MEND(EditableModel): | |
def get_shape(self, p): | |
# We need to (annoyingly) flip the shapes since OpenAI gpt2 uses convs instead of linear | |
return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0]) | |
def __init__(self, model, config, model_constructor, gtn=None, edit_lrs=None): | |
super().__init__(model, config, model_constructor) | |
if edit_lrs is None: | |
edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params))) | |
self.edit_lrs = edit_lrs | |
if not hasattr(self.model, "handles"): | |
hook_model(self.model, self.config.model.inner_params) | |
LOG.info(f"Hooked {len(self.model.handles)//2} modules") | |
if config.gtn.shared: | |
shape_dict = defaultdict(list) | |
for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params): | |
shape_dict[self.get_shape(p)].append(n) | |
self.shape_dict = shape_dict | |
if gtn is None: | |
if not config.gtn.shared: | |
self.gtn = nn.ModuleDict({ | |
n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.gtn) | |
for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params) | |
}) | |
else: | |
self.gtn = nn.ModuleDict({ | |
str(tuple(s)): GradientTransform(*s, config.gtn, len(shape_dict[s])) | |
for s in shape_dict.keys() | |
}) | |
else: | |
self.gtn = gtn | |
def state_dict(self, destination=None, prefix="", keep_vars=False): | |
state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict | |
model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params | |
for k in model_keys: | |
del state_dict[f"model.{k}"] | |
state_dict["model_config"] = self.model.config # Include model config | |
return state_dict | |
def load_state_dict(self, state_dict, strict: bool = True): | |
config = state_dict["model_config"] | |
del state_dict["model_config"] | |
if config != self.model.config: | |
LOG.info("Loaded model config doesn't match current model config.") | |
LOG.info(f"Loaded: {config}") | |
LOG.info(f"Current: {self.model.config}") | |
res = super().load_state_dict(state_dict, False) | |
# We should only have missing keys for the model, and no unexpected keys | |
assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model." | |
assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys" | |
return res | |
def outer_parameters(self, grouped=False): | |
if grouped: | |
return [ | |
dict(params=list(self.gtn.parameters()), lr=self.config.lr), | |
dict(params=[self.edit_lrs], lr=self.config.lr_lr) | |
] | |
else: | |
return list(self.gtn.parameters()) + [self.edit_lrs] | |
def edit(self, batch, condition=None, detach_history=False): | |
outputs = _logits(self.model(**batch)) | |
loss = self.edit_loss_fn(outputs, batch["labels"])["nll"] | |
names = set([n for n, p in self.model.named_parameters()]) | |
pset = set(self.config.model.inner_params) | |
for p in pset: | |
assert p in names, f"inner param {p} not in model" | |
loss.backward() | |
if self.config.gtn.shared: | |
param_idx = lambda n, p: self.shape_dict[self.get_shape(p)].index(n) if self.config.gtn.shared else None # noqa: E731 | |
transformed_factors = { | |
n: self.gtn[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p)) | |
for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params) | |
} | |
else: | |
transformed_factors = { | |
n: self.gtn[n.replace(".", "#")](p.__x__, p.__delta__) | |
for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params) | |
} | |
# Should be bi,bj->ji for nn.Linear, but [annoying] GPT2 uses Conv1d instead... | |
if isinstance(self.model, transformers.GPT2LMHeadModel): | |
targ = "ij" | |
else: | |
targ = "ji" | |
mean_grads = { | |
n: torch.einsum(f"bi,bj->{targ}", x, delta) | |
for n, (x, delta) in transformed_factors.items() | |
} | |
info_dict = {} | |
idx = 0 | |
for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params): | |
info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item() | |
info_dict[f"grad/pseudo_mag{idx}"] = mean_grads[n].norm(2).item() | |
info_dict[f"grad/true_std{idx}"] = p.grad.std().item() | |
info_dict[f"grad/pseudo_std{idx}"] = mean_grads[n].std().item() | |
info_dict[f"grad/diff{idx}"] = (p.grad - mean_grads[n]).norm(2).item() | |
info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), mean_grads[n].reshape(-1), dim=0).item() | |
idx += 1 | |
self.model.zero_grad() | |
assert len(self.edit_lrs) == len(list(mean_grads.items())) | |
updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())} | |
edited_model = self.model | |
if not isinstance(edited_model, higher.patch._MonkeyPatchBase): | |
edited_model = make_functional(edited_model, in_place=True) | |
new_params = [] | |
for n, p in edited_model.named_parameters(): | |
if n in pset: | |
if self.config.gtn.descent: | |
new_params.append(p - updates[n]) | |
else: | |
new_params.append(p + updates[n]) | |
else: | |
new_params.append(p) | |
edited_model.update_params(new_params) | |
if detach_history: | |
new_model = self.model_constructor() | |
new_model.load_state_dict(edited_model.state_dict()) | |
edited_model = new_model | |
return MEND(edited_model, self.config, self.model_constructor, self.gtn, edit_lrs=self.edit_lrs), info_dict | |
if __name__ == '__main__': | |
import types | |
model = transformers.GPT2LMHeadModel.from_pretrained("gpt2") | |
config = types.SimpleNamespace() | |
config.model.inner_params = [ | |
"transformer.h.9.mlp.c_fc.weight", | |
"transformer.h.9.mlp.c_proj.weight", | |
"transformer.h.10.mlp.c_fc.weight", | |
"transformer.h.10.mlp.c_proj.weight", | |
"transformer.h.11.mlp.c_fc.weight", | |
"transformer.h.11.mlp.c_proj.weight", | |
] | |
config.edit_lr = 0.0001 | |
config.gtn = types.SimpleNamespace() | |
config.gtn.n_hidden = 1 | |
config.gtn = config.gtn.__dict__ | |
gtn = MEND(model, config, lambda: copy.deepcopy(model)).cuda() | |
# torch.save(gtn.state_dict(), "test_state.pt") | |
import pdb; pdb.set_trace() | |
gtn.load_state_dict(torch.load("test_state.pt")) | |
x = torch.arange(20).view(1, 20).cuda() + 1000 | |
orig_logits = gtn(x) | |
edited = gtn.edit(x, masks=torch.ones_like(x), labels=x) | |
post_logits = gtn(x) | |
assert torch.allclose(orig_logits, post_logits) | |
orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0] | |
edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0] | |
LOG.info((orig_param - edited_param).abs().max()) | |
edited.eval() | |
LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"]) | |
edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x) | |
LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss) | |