"""SERAC editable-model wrapper.

Wraps a base model together with a scope classifier and a replacement
("counterfactual") model. Edits are stored as cached text; at inference time
the classifier scores each input against the cache and the output is either
a soft mixture of, or a hard switch between, the base model and the
replacement model conditioned on the best-matching cached edit.
"""

import copy
import logging
from collections import OrderedDict

import torch
import transformers

from utils import scr, set_dropout, _logits, add_padding, add_sep
from editable_model import EditableModel
from models import BertClassifier

LOG = logging.getLogger(__name__)


def translate_tokens(tokens, from_tok, to_tok):
    """Re-tokenize a batch of token ids from one tokenizer's vocabulary to another's.

    Positions equal to -100 are replaced by ``from_tok.pad_token_id`` before
    decoding; the decoded text is then encoded with ``to_tok`` and the
    resulting ``input_ids`` are moved to the device of ``tokens``.
    """
    tokens = tokens.masked_fill(tokens == -100, from_tok.pad_token_id)
    text = from_tok.batch_decode(tokens, skip_special_tokens=True)
    return to_tok(text, return_tensors="pt")["input_ids"].to(tokens.device)


class SERAC(EditableModel):
    """Editable model that routes inputs between a base and a replacement model.

    Attributes set in ``__init__``:
        classifier / classifier_tok: scope classifier and its tokenizer.
        replacement / replacement_tok: replacement model and its tokenizer.
        scale: multiplier applied to negative distances in embedding mode
            (``None`` in cross-attend mode).
        cache_inputs / cache_labels: lists of decoded edit inputs / labels.
        cache_embeds: dict mapping classifier input text to its embedding;
            only created when ``config.rep.cache_embeds`` is set and
            ``config.rep.cross_attend`` is not.
    """

    def __init__(self, model, config, model_constructor, classifier=None, classifier_tok=None,
                 replacement=None, replacement_tok=None, cache_inputs=None, cache_labels=None,
                 cache_embeds=None, scale=None):
        """Build (or adopt) the classifier, replacement model, scale and edit cache.

        When ``classifier`` / ``replacement`` are ``None`` they are loaded from
        the names in ``config``; otherwise the passed modules and tokenizers
        are reused as-is (this is how ``edit`` shares weights with the model
        it returns). The cache lists/dict are deep-copied when provided.
        """
        super().__init__(model, config, model_constructor)

        if classifier is None:
            if config.rep.cross_attend and not config.rep.cls_class.endswith("ForSequenceClassification"):
                LOG.warning(
                    f"Switching {config.rep.cls_class} to {config.rep.cls_class}"
                    f"ForSequenceClassification for cross-attend"
                )
                config.rep.cls_class += "ForSequenceClassification"
            self.classifier = getattr(transformers, config.rep.cls_class).from_pretrained(
                config.rep.cls_name, cache_dir=scr()
            )
            if self.config.rep.checkpoint_grad:
                LOG.info(
                    f"Checking for checkpointing: {hasattr(self.classifier.config, 'gradient_checkpointing')}"
                )
                self.classifier.config.gradient_checkpointing = True
            self.classifier_tok = transformers.AutoTokenizer.from_pretrained(
                config.rep.cls_name, cache_dir=scr()
            )
            if not self.config.rep.cross_attend and 'bert' in self.config.rep.cls_name:
                self.classifier.pooler = None  # we don't need the classification head
            elif not self.config.rep.cross_attend and "mpnet" not in self.config.rep.cls_name:
                if hasattr(self.classifier, "pooler"):
                    self.classifier.pooler = None  # we don't need the classification head

            set_dropout(self.classifier, config.dropout)
            if self.config.rep.lora is not None:
                # NOTE(review): LoraModel is not imported anywhere in this file, so this
                # branch raises NameError unless the name is injected elsewhere — confirm.
                self.classifier = LoraModel(self.classifier, self.config.rep.lora)
        else:
            assert isinstance(classifier, torch.nn.Module), f"Classifier is a {type(classifier)}!"
            assert isinstance(classifier_tok, transformers.PreTrainedTokenizerBase), \
                f"Classifier tok is {type(classifier_tok)}!"
            self.classifier, self.classifier_tok = classifier, classifier_tok

        if replacement is None:
            self.replacement_tok = transformers.AutoTokenizer.from_pretrained(
                config.model.small_name, cache_dir=scr()
            )
            # The replacement model is always instantiated, even when
            # config.rep.freeze_cntr is set (the skip was hard-disabled upstream).
            if config.model.class_name == "BertClassifier":
                self.replacement = BertClassifier(config.model.small_name)
            else:
                self.replacement = getattr(transformers, config.model.class_name).from_pretrained(
                    config.model.small_name, cache_dir=scr()
                )
            if self.replacement_tok.sep_token is None and "gpt" not in self.model.name_or_path.lower():
                add_sep(self.replacement_tok, self.replacement)
            if self.replacement_tok.pad_token is None:
                add_padding(self.replacement_tok, self.replacement)
            set_dropout(self.replacement, config.dropout)
        else:
            # These messages were plain strings before, so the type was never interpolated.
            assert isinstance(replacement, torch.nn.Module), f"Rep is {type(replacement)}!"
            assert isinstance(replacement_tok, transformers.PreTrainedTokenizerBase), \
                f"Rep tok is {type(replacement_tok)}!"
            self.replacement, self.replacement_tok = replacement, replacement_tok

        if self.config.rep.cross_attend:
            self.scale = None
        else:
            if scale is None:
                self.register_buffer("scale", torch.tensor(1.0))
            else:
                self.scale = scale

        if cache_inputs is None:
            self.cache_inputs = []
            self.cache_labels = []
            if config.rep.cache_embeds and not config.rep.cross_attend:
                self.cache_embeds = {}
        else:
            assert isinstance(cache_inputs, list), f"Cache inputs is {cache_inputs}"
            assert isinstance(cache_labels, list), f"Cache labels is {cache_labels}"
            self.cache_inputs = copy.deepcopy(cache_inputs)
            self.cache_labels = copy.deepcopy(cache_labels)
            if config.rep.cache_embeds and not config.rep.cross_attend:
                assert isinstance(cache_embeds, dict), f"Cache embeds is {cache_embeds}"
                self.cache_embeds = copy.deepcopy(cache_embeds)

    def _base_is_gpt(self):
        """Return True when the base model exposes a ``name_or_path`` containing "gpt"."""
        return hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower()

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the state dict without base-model weights, plus the base model config.

        Keys under ``model.`` are always removed; keys under ``replacement.``
        are removed as well when ``config.rep.freeze_cntr`` is set. The base
        model's config is stored under ``"model_config"``.

        ``destination`` is accepted for signature compatibility but not used.
        """
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        # Sub-module keys are "<prefix>model.<name>", so enumerate the model's own
        # (unprefixed) names and rebuild the full key; this also works for prefix != "".
        for k in list(self.model.state_dict(keep_vars=keep_vars).keys()):  # Remove model params
            del state_dict[f"{prefix}model.{k}"]
        if self.config.rep.freeze_cntr:
            for k in list(self.replacement.state_dict().keys()):
                del state_dict[f"{prefix}replacement.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load a dict produced by ``state_dict``.

        The ``"model_config"`` entry is compared against the current base
        model config (a mismatch is only logged). Loading is always performed
        non-strictly; ``strict`` is accepted but ignored. If loading raises
        ``RuntimeError``, it is retried without the ``replacement*`` entries.

        The caller's ``state_dict`` is not modified.

        Raises:
            AssertionError: if keys other than the tolerated ones are missing,
                or if there are unexpected keys.
        """
        # Work on a shallow copy so the caller can reuse its dict; keep the
        # version metadata that torch attaches to state dicts.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata

        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        # state_dict() strips replacement weights when freeze_cntr is set, so
        # they are legitimately absent from such checkpoints.
        replacement_may_be_missing = bool(self.config.rep.freeze_cntr)
        try:
            res = super().load_state_dict(state_dict, False)
        except RuntimeError:
            LOG.info("Load failed; trying again without loading counterfactual model weights.")
            for k in [k for k in state_dict if k.startswith("replacement")]:
                del state_dict[k]
            res = super().load_state_dict(state_dict, False)
            # We deliberately dropped them, so they will be reported as missing.
            replacement_may_be_missing = True

        # We should only have missing keys for the model, and no unexpected keys
        def ok_to_miss(k):
            return k.startswith("model.") or (replacement_may_be_missing and k.startswith("replacement."))

        missing_keys = [k for k in res.missing_keys if not ok_to_miss(k)]
        assert len(missing_keys) == 0, f"Should only have missing keys for model: {missing_keys}."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        """Collect the trainable parameters of the classifier and replacement model.

        With ``config.rep.freeze`` set, only the last ``freeze`` entries of the
        first ``ModuleList`` found in the classifier are taken; otherwise all
        classifier parameters are. Replacement-model parameters are always
        appended.

        Returns:
            A flat list of parameters, or (``grouped=True``) two optimizer
            param-group dicts using ``config.lr`` and ``config.lr_lr``.
        """
        if self.config.rep.freeze is not None:
            modlist = None
            for m in self.classifier.modules():
                if isinstance(m, torch.nn.ModuleList):
                    modlist = m
                    break
            model_params = list(modlist[-self.config.rep.freeze:].parameters())
        else:
            model_params = list(self.classifier.parameters())

        if self.config.rep.lora is not None or self.config.rep.freeze is not None:
            cls = self.classifier.base_model if self.config.rep.lora else self.classifier
            if hasattr(cls, "classifier"):
                model_params.extend(cls.classifier.parameters())
            if hasattr(cls, "pre_classifier"):
                model_params.extend(cls.pre_classifier.parameters())

        model_params.extend(list(self.replacement.parameters()))

        extra_params = []
        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        return model_params + extra_params

    def edit(self, batch, condition=None, detach_history=False):
        """Return a new ``SERAC`` whose cache additionally contains ``batch``.

        ``batch["input_ids"]`` and ``batch["labels"]`` are decoded with the
        replacement tokenizer (labels become empty strings when
        ``config.model.name`` contains "bert") and appended to the cache. The
        classifier, replacement model, tokenizers and scale are shared with
        the returned model. ``condition`` and ``detach_history`` are unused.

        Returns:
            ``(new_model, {})``.
        """
        def detokenize(toks, tok):
            tokens = toks.masked_fill(toks == -100, tok.pad_token_id)
            return tok.batch_decode(tokens, skip_special_tokens=True)

        inputs = detokenize(batch["input_ids"], self.replacement_tok)
        if "bert" in self.config.model.name:
            labels = ["" for _ in batch["labels"]]
        else:
            labels = detokenize(batch["labels"], self.replacement_tok)

        cache_inputs = self.cache_inputs + inputs
        cache_labels = self.cache_labels + labels

        if self.config.rep.cache_embeds and not self.config.rep.cross_attend:
            cls_inputs = self.build_cls_cache_inputs(inputs, labels)
            with torch.no_grad():
                embeds = self.compute_cls_embeddings(cls_inputs)
            cache_embeds = {inp: emb for inp, emb in zip(cls_inputs, embeds)}
            # Existing entries win over freshly computed ones for identical text.
            cache_embeds.update(self.cache_embeds)
        else:
            cache_embeds = None

        new_model = SERAC(self.model, self.config, self.model_constructor, self.classifier,
                          self.classifier_tok, self.replacement, self.replacement_tok,
                          cache_inputs, cache_labels, cache_embeds, self.scale)
        new_model.train(self.training)
        return new_model, {}

    def stats(self):
        """Return ``self.last_stats``.

        NOTE(review): nothing in this file assigns ``last_stats``; presumably
        it is set by the base class or a caller — confirm.
        """
        return self.last_stats

    def compute_cls_embeddings(self, text):
        """Embed ``text`` with the classifier.

        Uses the first-position hidden state when ``config.rep.cls_name``
        contains "bert", otherwise ``pooler_output``. The result is reshaped
        to ``(len(text), config.rep.dist_heads, -1)`` and passed through
        ``tanh`` when ``config.rep.bound_embeds`` is set.
        """
        inputs = self.classifier_tok(text, return_tensors="pt", padding=True).to(self.config.device)
        if 'bert' in self.config.rep.cls_name:
            embeds = self.classifier(**inputs).last_hidden_state[:, 0].unsqueeze(1)
        else:
            embeds = self.classifier(**inputs).pooler_output.unsqueeze(1)
        embeds = embeds.view(embeds.shape[0], self.config.rep.dist_heads, -1)
        if self.config.rep.bound_embeds:
            embeds = embeds.tanh()
        return embeds

    def embedding_logsim_matrix(self, cls_ctxs, test_input_text):
        """Return a (num_test x num_ctx) matrix of log-similarities from embedding distances.

        Distances are cosine (``config.rep.cos``) or Euclidean, optionally
        squared (``config.rep.square``), minimized over the heads dimension,
        then negated and multiplied by ``self.scale``. Cached context
        embeddings are used in eval mode when embedding caching is enabled.
        """
        if self.config.rep.cache_embeds and not self.config.rep.cross_attend and not self.training:
            # Each cached entry is one (heads, D) row of compute_cls_embeddings' output.
            # stack restores the leading cache axis; cat would merge it into the heads axis.
            ctx_embeds = torch.stack([self.cache_embeds[ctx] for ctx in cls_ctxs])
        else:
            ctx_embeds = self.compute_cls_embeddings(cls_ctxs)

        main_embeds = self.compute_cls_embeddings(test_input_text)

        if self.config.rep.cos:
            cos = (ctx_embeds[None] * main_embeds[:, None]).sum(-1) / (
                ctx_embeds[None].norm(2, -1) * main_embeds[:, None].norm(2, -1)
            )
            dists = 1 - cos
        else:
            dists = (ctx_embeds[None] - main_embeds[:, None]).norm(2, -1)
            if self.config.rep.square:
                dists = dists ** 2

        dists = dists.min(-1).values  # get rid of the dists head dimension

        assert dists.min() >= 0, "Shouldn't have negative distances!"
        cls_logsims = -dists * self.scale
        return cls_logsims

    def crossattend_logsim_matrix(self, cls_ctxs, test_input_texts):
        """Return a (num_test x num_ctx) matrix of log-similarities from a pair classifier.

        Every (context, test) pair is joined with the classifier tokenizer's
        separator token, classified, and the log-softmax value at index 0 is
        taken as the pair's log-similarity.
        """
        batch = [
            ctx + self.classifier_tok.sep_token + test
            for test in test_input_texts
            for ctx in cls_ctxs
        ]
        batch_toks = self.classifier_tok(batch, return_tensors="pt", padding=True).to(self.config.device)
        batch_logsims = self.classifier(**batch_toks).logits.log_softmax(-1)[:, 0]
        logsim_matrix = batch_logsims.view(len(test_input_texts), len(cls_ctxs))
        return logsim_matrix

    def build_rep_cache_contexts(self):
        """Return the text prepended to replacement-model inputs, one entry per cached edit."""
        sep = " "
        if self._base_is_gpt():
            # The labels are included in the inputs for autoregressive models,
            # so only the cached input is used.
            ctxs = [cin + sep for cin in self.cache_inputs]
        else:
            ctxs = [cin + sep + clab + sep for cin, clab in zip(self.cache_inputs, self.cache_labels)]
        return ctxs

    def build_cls_cache_inputs(self, cache_inputs=None, cache_labels=None):
        """Return the classifier-side text for cached edits.

        Defaults to the model's own cache when the arguments are ``None``.
        """
        sep = self.classifier_tok.sep_token
        if cache_inputs is None:
            cache_inputs = self.cache_inputs
        if cache_labels is None:
            cache_labels = self.cache_labels
        if self._base_is_gpt():
            # The labels are included in the inputs for autoregressive models.
            # Cut off the label (text after the last space) for the classifier.
            inputs = [cin.rsplit(" ", 1)[0] + sep for cin in cache_inputs]
        else:
            inputs = [cin + sep + clab + sep for cin, clab in zip(cache_inputs, cache_labels)]
        return inputs

    def build_rep_input_tokens(self, kwargs, idxs, generation=False):
        """Tokenize replacement-model inputs: selected cache context + decoded test input.

        Args:
            kwargs: model kwargs; must contain ``"input_ids"`` and, unless
                ``generation`` is true, ``"labels"``.
            idxs: one cache index per test input.
            generation: when true, labels are not copied into the result.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and possibly ``labels``.
        """
        assert len(idxs) == len(kwargs["input_ids"]), "Need one cache idx for each test input"
        cache_contexts = self.build_rep_cache_contexts()
        selected_contexts = [cache_contexts[idx.item()] for idx in idxs]
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
        rep_texts = [ctx + inp for ctx, inp in zip(selected_contexts, test_inputs)]
        rep_input_tokens = self.replacement_tok(rep_texts, return_tensors="pt", padding=True).to(
            self.config.device
        )

        rep_kwargs = {
            "input_ids": rep_input_tokens["input_ids"],
            "attention_mask": rep_input_tokens["attention_mask"],
        }

        if not generation:
            rep_kwargs["labels"] = kwargs["labels"]

        # Generation calls usually carry no labels; only pad when there are some,
        # instead of failing with KeyError.
        if self._base_is_gpt() and "labels" in kwargs:
            # Add 'ignore' labels for the prepended cache inputs
            labels = kwargs["labels"]
            pre = torch.full(
                (labels.shape[0], rep_kwargs["input_ids"].shape[-1] - labels.shape[-1]),
                -100,
                device=labels.device,
            )
            rep_kwargs["labels"] = torch.cat((pre, labels), dim=-1)

        return rep_kwargs

    def run_classifier(self, *inputs, **kwargs):
        """Score ``kwargs["input_ids"]`` against every cached edit.

        Returns:
            ``(cls_sims, cls_idxs, log_sim_matrix)``: the best similarity per
            test input, the index of the best cached edit, and the full
            log-similarity matrix.
        """
        cache_inputs = self.build_cls_cache_inputs()
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)

        if self.config.rep.cross_attend:
            log_sim_matrix = self.crossattend_logsim_matrix(cache_inputs, test_inputs)
        else:
            log_sim_matrix = self.embedding_logsim_matrix(cache_inputs, test_inputs)

        sims = log_sim_matrix.exp()
        assert sims.max() <= 1, "Similarities shouldn't exceed 1!"

        cls_sims, cls_idxs = sims.max(-1)
        return cls_sims, cls_idxs, log_sim_matrix

    def generate(self, *args, **kwargs):
        """Generate with either the base model or the replacement model.

        ``max_new_tokens`` defaults to 20. With a non-empty cache, the
        replacement path is taken when the classifier similarity exceeds 0.5,
        or as dictated by the optional ``override`` keyword (``"cntr"`` forces
        the replacement path, any other non-``None`` value forces the base
        path). ``override`` is never forwarded to the underlying model. When
        ``config.rep.freeze_cntr`` is set, the base model is run on the
        context-augmented input instead of the replacement model.

        The similarity check uses ``.item()`` and therefore expects a single
        test input when no override is given.
        """
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 20

        # Always strip the routing flag so it cannot leak into model.generate().
        override = kwargs.pop("override", None)

        def base_generate_fn(*a, **kw):
            if isinstance(self.model, BertClassifier):
                return self.model.forward(*a, **kw)
            return self.model.generate(*a, **kw)

        def cntr_generate_fn(*a, **kw):
            if isinstance(self.replacement, BertClassifier):
                return self.replacement.forward(*a, **kw)
            return self.replacement.generate(*a, **kw)

        if len(self.cache_inputs) > 0:
            cls_sims, cls_idxs, _ = self.run_classifier(*args, **kwargs)
            use_cntr = (override == "cntr") if override is not None else (cls_sims.item() > 0.5)
            if use_cntr:
                rep_input = self.build_rep_input_tokens(kwargs, cls_idxs, generation=True)
                kwargs["input_ids"] = rep_input["input_ids"]
                kwargs["attention_mask"] = rep_input["attention_mask"]
                if self.config.rep.freeze_cntr:
                    return base_generate_fn(*args, **kwargs)
                return cntr_generate_fn(*args, **kwargs)

        return base_generate_fn(*args, **kwargs)

    def forward(self, *inputs, return_logits_only=True, eps=torch.finfo(torch.float32).eps,
                pos_pairs=None, **kwargs):
        """Run the base model, classifier and replacement model and combine their outputs.

        Gradient mode is set to ``self.training`` for the duration of the call
        and the caller's mode is restored afterwards, on every exit path
        including exceptions.

        Args:
            return_logits_only: when true only the mixed logits are returned.
            eps: added to the mixed probabilities before taking the log.
            pos_pairs: optional tensor whose first column must equal
                ``arange(len)`` and whose second column holds the gold cache
                index for each input.

        Returns:
            The base model output (float) when the cache is empty. Otherwise
            ``mixture_logits`` or, with ``return_logits_only=False``,
            ``(mixture_logits, cls_logits, rep_gold_logits, stats)``.
        """
        grad_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(self.training)
        try:
            # need to do soft mixing of logits if we're doing supervised training
            # or we've specifically requested it
            soft = (not self.config.rep.supervised) or self.config.rep.soft_weighting

            base_probs = None
            with torch.no_grad():
                base_logits = super().forward(*inputs, **kwargs).float()
                if len(self.cache_inputs) == 0:
                    return base_logits
                if soft:
                    if base_logits.dim() == 3:
                        base_probs = base_logits.softmax(-1)
                    else:
                        base_probs = base_logits.sigmoid()
                    del base_logits

            cls_sims, cls_idxs, cls_logits = self.run_classifier(*inputs, **kwargs)
            rep_cls_inputs = self.build_rep_input_tokens(kwargs, cls_idxs)
            if self.config.rep.freeze_cntr:
                rep_cls_logits = _logits(super().forward(**rep_cls_inputs))
            else:
                rep_cls_logits = _logits(self.replacement(**rep_cls_inputs))

            if pos_pairs is not None:
                assert (pos_pairs[:, 0] == torch.arange(pos_pairs.shape[0], device=pos_pairs.device)).all()
                gold_idxs = pos_pairs[:, 1]
                rep_gold_inputs = self.build_rep_input_tokens(kwargs, gold_idxs)
                # Gold logits always come from the replacement model (the
                # freeze_cntr alternative was hard-disabled upstream).
                rep_gold_logits = _logits(self.replacement(**rep_gold_inputs))
            else:
                rep_gold_logits = rep_cls_logits

            cls_sims = cls_sims.view(-1, 1)  # For (binary) classification, predictions are (B x 1)
            if rep_cls_logits.dim() == 3:
                cls_sims.unsqueeze_(-1)  # For generation/seq2seq, predictions are (B x S x V)

            stats = {
                'sims/mean': cls_sims.mean().item(),
                'sims/pos': (cls_sims >= 0.5).float().mean().item(),
                'sims/neg': (cls_sims < 0.5).float().mean().item(),
                'params/scale': self.scale.item() if self.scale is not None else 0.0,
            }

            if self._base_is_gpt():
                # Drop the positions that belong to the prepended cache context.
                rep_cls_logits = rep_cls_logits[:, -kwargs["labels"].shape[-1]:, :]

            if soft:
                rep_weight = cls_sims
                if base_probs.dim() == 3:
                    rep_probs = rep_cls_logits.softmax(-1)
                else:
                    rep_probs = rep_cls_logits.sigmoid()
                mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_probs + eps).log()
            else:
                rep_idxs = torch.where(cls_sims > 0.5)[0]
                mixture_logits = base_logits
                if rep_idxs.numel() > 0:
                    mixture_logits[rep_idxs] = rep_cls_logits[rep_idxs]

            if return_logits_only:
                return mixture_logits
            return mixture_logits, cls_logits, rep_gold_logits, stats
        finally:
            # Runs after the no_grad context has exited, so the caller's mode
            # is what is left in place.
            torch.set_grad_enabled(grad_enabled)


if __name__ == '__main__':
    # NOTE(review): this smoke test looks like a leftover from a different editor
    # class: it calls edit() with tensor/masks/labels arguments that SERAC.edit
    # does not accept and builds a config without the `rep` section SERAC reads.
    # Only the outright crashes local to this script are fixed below.
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()  # must exist before attributes are set on it
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.edit_lr = 0.0001

    config.gtn = types.SimpleNamespace()
    config.gtn.n_hidden = 1
    config.gtn = config.gtn.__dict__

    gtn = SERAC(model, config, lambda: copy.deepcopy(model)).cuda()
    # torch.save(gtn.state_dict(), "test_state.pt")
    import pdb
    pdb.set_trace()
    gtn.load_state_dict(torch.load("test_state.pt"))
    x = torch.arange(20).view(1, 20).cuda() + 1000
    orig_logits = gtn(x)
    edited = gtn.edit(x, masks=torch.ones_like(x), labels=x)
    post_logits = gtn(x)

    assert torch.allclose(orig_logits, post_logits)

    orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    LOG.info((orig_param - edited_param).abs().max())
    edited.eval()
    # logging treats extra positional args as %-format arguments, so give it a format string.
    LOG.info("%s %s %s", gtn(x, labels=x).loss, edited(x, labels=x).loss,
             edited.edit_loss_fn(edited(x).logits, x)["nll"])
    edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
    LOG.info("%s %s %s", gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)