import torch
import copy
import transformers
import logging
from utils import scr, set_dropout, _logits, add_padding, add_sep
from editable_model import EditableModel
from models import BertClassifier
LOG = logging.getLogger(__name__)
def translate_tokens(tokens, from_tok, to_tok):
    """Re-encode token ids produced by one tokenizer with another tokenizer.

    Positions holding ``-100`` are first overwritten with ``from_tok.pad_token_id`` so they
    can be decoded; the batch is then decoded to text with ``skip_special_tokens=True`` and
    re-tokenized by ``to_tok``.

    Args:
        tokens: integer tensor of ids in ``from_tok``'s vocabulary.
        from_tok: tokenizer used to decode ``tokens`` (needs ``pad_token_id`` and
            ``batch_decode``).
        to_tok: tokenizer used to encode the decoded strings.

    Returns:
        The ``"input_ids"`` tensor produced by ``to_tok``, moved to ``tokens.device``.
    """
    ignore_positions = tokens == -100
    cleaned = tokens.masked_fill(ignore_positions, from_tok.pad_token_id)
    decoded_text = from_tok.batch_decode(cleaned, skip_special_tokens=True)
    encoded = to_tok(decoded_text, return_tensors="pt")
    return encoded["input_ids"].to(cleaned.device)
class SERAC(EditableModel):
    """SERAC editor: an explicit text cache of edits, a scope classifier and a counterfactual model.

    Edits are stored as strings in ``cache_inputs`` / ``cache_labels``. For each test input the
    classifier scores every cached edit (``run_classifier``); the best-matching edit is prepended
    to the input and fed to the replacement ("counterfactual") model. ``forward`` combines the
    base-model and replacement predictions (a similarity-weighted mixture when soft weighting is
    on, otherwise a hard switch where similarity > 0.5) and ``generate`` routes to one of them.

    ``edit`` never mutates ``self``: it returns a new ``SERAC`` that shares the base model,
    classifier and replacement modules but carries an extended cache.
    """

    def __init__(self, model, config, model_constructor, classifier=None, classifier_tok=None,
                 replacement=None, replacement_tok=None, cache_inputs=None, cache_labels=None,
                 cache_embeds=None, scale=None):
        """Build the editor, loading any component that is not passed in.

        Args:
            model: the base model being edited (handed to ``EditableModel``).
            config: experiment config; reads ``config.rep.*``, ``config.model.*`` and
                ``config.dropout``.
            model_constructor: forwarded to ``EditableModel``.
            classifier, classifier_tok: scope classifier and its tokenizer. Loaded from
                ``config.rep.cls_class`` / ``config.rep.cls_name`` when ``classifier`` is None.
            replacement, replacement_tok: counterfactual model and its tokenizer. Loaded from
                ``config.model.class_name`` / ``config.model.small_name`` when ``replacement``
                is None.
            cache_inputs, cache_labels: lists of cached edit strings; deep-copied.
            cache_embeds: dict mapping classifier input text to a precomputed embedding; only
                used when ``config.rep.cache_embeds`` is set and not cross-attending. Deep-copied.
            scale: multiplier applied to embedding distances. When None (and not
                cross-attending) a buffer initialised to 1.0 is registered.
        """
        super().__init__(model, config, model_constructor)

        if classifier is None:
            if config.rep.cross_attend and not config.rep.cls_class.endswith("ForSequenceClassification"):
                LOG.warning(f"Switching {config.rep.cls_class} to {config.rep.cls_class}ForSequenceClassification for cross-attend")
                # NOTE: this mutates the (possibly shared) config object.
                config.rep.cls_class += "ForSequenceClassification"
            self.classifier = getattr(transformers, config.rep.cls_class).from_pretrained(config.rep.cls_name, cache_dir=scr())
            if self.config.rep.checkpoint_grad:
                LOG.info(f"Checking for checkpointing: {hasattr(self.classifier.config, 'gradient_checkpointing')}")
                self.classifier.config.gradient_checkpointing = True
            self.classifier_tok = transformers.AutoTokenizer.from_pretrained(config.rep.cls_name, cache_dir=scr())
            if not self.config.rep.cross_attend:
                # Embedding mode: compute_cls_embeddings reads the first-token hidden state for
                # 'bert' names (pooler unused) and pooler_output otherwise, so the pooler is only
                # kept for mpnet.
                cls_name = self.config.rep.cls_name
                if 'bert' in cls_name or ("mpnet" not in cls_name and hasattr(self.classifier, "pooler")):
                    self.classifier.pooler = None  # we don't need the classification head

            set_dropout(self.classifier, config.dropout)

            if self.config.rep.lora is not None:
                # NOTE(review): LoraModel is not imported in this file, so this raises NameError
                # unless it is provided elsewhere — TODO add the import or drop the option.
                self.classifier = LoraModel(self.classifier, self.config.rep.lora)
        else:
            assert isinstance(classifier, torch.nn.Module), f"Classifier is a {type(classifier)}!"
            assert isinstance(classifier_tok, transformers.PreTrainedTokenizerBase), f"Classifier tok is {type(classifier_tok)}!"
            self.classifier, self.classifier_tok = classifier, classifier_tok

        if replacement is None:
            self.replacement_tok = transformers.AutoTokenizer.from_pretrained(config.model.small_name, cache_dir=scr())
            # The replacement model is built even when config.rep.freeze_cntr is set:
            # outer_parameters() includes its parameters and forward() uses it for gold contexts.
            if config.model.class_name == "BertClassifier":
                self.replacement = BertClassifier(config.model.small_name)
            else:
                self.replacement = getattr(transformers, config.model.class_name).from_pretrained(config.model.small_name, cache_dir=scr())
            if self.replacement_tok.sep_token is None and "gpt" not in self.model.name_or_path.lower():
                add_sep(self.replacement_tok, self.replacement)
            if self.replacement_tok.pad_token is None:
                add_padding(self.replacement_tok, self.replacement)
            set_dropout(self.replacement, config.dropout)
        else:
            assert isinstance(replacement, torch.nn.Module), f"Rep is {type(replacement)}!"
            assert isinstance(replacement_tok, transformers.PreTrainedTokenizerBase), f"Rep tok is {type(replacement_tok)}!"
            self.replacement, self.replacement_tok = replacement, replacement_tok

        if self.config.rep.cross_attend:
            self.scale = None  # crossattend_logsim_matrix does not use a scale
        elif scale is None:
            self.register_buffer("scale", torch.tensor(1.0))
        else:
            self.scale = scale

        if cache_inputs is None:
            self.cache_inputs = []
            self.cache_labels = []
            if config.rep.cache_embeds and not config.rep.cross_attend:
                self.cache_embeds = {}
        else:
            assert isinstance(cache_inputs, list), f"Cache inputs is {cache_inputs}"
            assert isinstance(cache_labels, list), f"Cache labels is {cache_labels}"
            self.cache_inputs = copy.deepcopy(cache_inputs)
            self.cache_labels = copy.deepcopy(cache_labels)
            if config.rep.cache_embeds and not config.rep.cross_attend:
                assert isinstance(cache_embeds, dict), f"Cache embeds is {cache_embeds}"
                self.cache_embeds = copy.deepcopy(cache_embeds)

    def _base_is_gpt(self):
        """Return True when the wrapped base model has a ``name_or_path`` containing 'gpt'."""
        return hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower()

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the editor's state without the base model's weights.

        ``model.*`` entries are dropped, as are ``replacement.*`` entries when
        ``config.rep.freeze_cntr`` is set. The base model's config is stored under
        ``"model_config"`` so ``load_state_dict`` can compare it. ``destination`` is accepted
        for signature compatibility but ignored.
        """
        state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars)  # Get default state dict
        # Full keys look like "<prefix>model.<name>": list the submodule's keys without a
        # prefix and rebuild the full key here.
        for k in self.model.state_dict(keep_vars=keep_vars):
            del state_dict[f"{prefix}model.{k}"]
        if self.config.rep.freeze_cntr:
            for k in self.replacement.state_dict():
                del state_dict[f"{prefix}replacement.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load state produced by ``state_dict``.

        Loading is always non-strict (``strict`` is ignored); afterwards it asserts that the
        only missing keys are base-model weights, plus replacement weights when
        ``config.rep.freeze_cntr`` is set or when the retry below discarded them, and that no
        keys are unexpected. If the first load raises ``RuntimeError`` it retries without any
        ``replacement*`` entries. The caller's dict is not modified.

        Returns:
            The result of ``torch.nn.Module.load_state_dict`` (missing/unexpected keys).
        """
        state_dict = copy.copy(state_dict)  # shallow copy: don't mutate the caller's dict
        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        dropped_replacement = False
        try:
            res = super().load_state_dict(state_dict, False)
        except RuntimeError:
            LOG.info("Load failed; trying again without loading counterfactual model weights.")
            for k in [k for k in state_dict if k.startswith("replacement")]:
                del state_dict[k]
            dropped_replacement = True
            res = super().load_state_dict(state_dict, False)

        # Base-model weights are never saved (see state_dict). Replacement weights are omitted
        # under freeze_cntr, or were just discarded by the retry path.
        allow_missing_replacement = self.config.rep.freeze_cntr or dropped_replacement

        def ok_to_miss(k):
            return k.startswith("model.") or (allow_missing_replacement and k.startswith("replacement."))

        missing_keys = [k for k in res.missing_keys if not ok_to_miss(k)]
        assert len(missing_keys) == 0, f"Should only have missing keys for model: {missing_keys}."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def outer_parameters(self, grouped=False):
        """Return the parameters optimized by the outer (training) loop.

        These are the classifier parameters, followed by all replacement-model parameters.
        When ``config.rep.freeze`` is set, only the last ``freeze`` entries of the classifier's
        last ``torch.nn.ModuleList`` are included. With LoRA or freezing, any ``classifier`` /
        ``pre_classifier`` head is added too.

        Args:
            grouped: if True, return two optimizer param groups with learning rates
                ``config.lr`` and ``config.lr_lr``; otherwise a flat list.

        Raises:
            ValueError: ``config.rep.freeze`` is set but the classifier has no ModuleList.
        """
        if self.config.rep.freeze is not None:
            modlist = None
            for m in self.classifier.modules():
                if isinstance(m, torch.nn.ModuleList):
                    modlist = m  # the last ModuleList in traversal order wins
            if modlist is None:
                raise ValueError("config.rep.freeze is set but the classifier has no torch.nn.ModuleList")
            # Slice from an explicit start so that freeze == 0 selects no layers
            # (modlist[-0:] would select all of them).
            start = max(len(modlist) - self.config.rep.freeze, 0)
            model_params = list(modlist[start:].parameters())
        else:
            model_params = list(self.classifier.parameters())

        if self.config.rep.lora is not None or self.config.rep.freeze is not None:
            cls = self.classifier.base_model if self.config.rep.lora else self.classifier
            if hasattr(cls, "classifier"):
                model_params.extend(cls.classifier.parameters())
            if hasattr(cls, "pre_classifier"):
                model_params.extend(cls.pre_classifier.parameters())

        # Replacement parameters are always included, even with freeze_cntr.
        model_params.extend(self.replacement.parameters())

        extra_params = []
        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        return model_params + extra_params

    def edit(self, batch, condition=None, detach_history=False):
        """Return a new SERAC whose cache additionally contains the edits in ``batch``.

        ``batch["input_ids"]`` and ``batch["labels"]`` are decoded back to text with the
        replacement tokenizer (``-100`` positions are treated as padding). ``condition`` and
        ``detach_history`` are accepted for interface compatibility and unused.

        Returns:
            ``(new_model, {})``; ``self`` is left unchanged.
        """
        def detokenize(toks, tok):
            tokens = toks.masked_fill(toks == -100, tok.pad_token_id)
            return tok.batch_decode(tokens, skip_special_tokens=True)

        inputs = detokenize(batch["input_ids"], self.replacement_tok)
        # NOTE(review): presumably BERT-style tasks are classification, whose labels are not
        # token ids, so they are stored as empty strings — confirm the attribute checked here.
        if "bert" in self.config.model.name:
            labels = ["" for _ in batch["labels"]]
        else:
            labels = detokenize(batch["labels"], self.replacement_tok)

        cache_inputs = self.cache_inputs + inputs
        cache_labels = self.cache_labels + labels

        if self.config.rep.cache_embeds and not self.config.rep.cross_attend:
            cls_inputs = self.build_cls_cache_inputs(inputs, labels)
            with torch.no_grad():
                embeds = self.compute_cls_embeddings(cls_inputs)
            # Keep earlier edits' embeddings: embedding_logsim_matrix looks up every cached
            # context, not only the ones added by this call.
            cache_embeds = {**self.cache_embeds, **dict(zip(cls_inputs, embeds))}
        else:
            cache_embeds = None

        new_model = SERAC(self.model, self.config, self.model_constructor, self.classifier, self.classifier_tok,
                          self.replacement, self.replacement_tok, cache_inputs, cache_labels, cache_embeds, self.scale)
        # A new Module starts in training mode. Copy only our flag (not recursively: the
        # submodules are shared) so mode-dependent paths such as the cached-embedding lookup
        # behave as they do for the parent.
        new_model.training = self.training
        return new_model, {}

    def stats(self):
        """Return ``self.last_stats``.

        NOTE(review): nothing in this class assigns ``last_stats`` (forward builds a stats dict
        but only returns it), so this raises AttributeError unless it is set elsewhere — TODO
        confirm.
        """
        return self.last_stats

    def compute_cls_embeddings(self, text):
        """Embed a list of strings with the classifier.

        Uses the first-token hidden state for 'bert' classifier names and ``pooler_output``
        otherwise. Returns a tensor viewed as ``(len(text), config.rep.dist_heads, -1)``,
        passed through ``tanh`` when ``config.rep.bound_embeds`` is set.
        """
        inputs = self.classifier_tok(text, return_tensors="pt", padding=True).to(self.config.device)
        if 'bert' in self.config.rep.cls_name:
            embeds = self.classifier(**inputs).last_hidden_state[:, 0].unsqueeze(1)
        else:
            embeds = self.classifier(**inputs).pooler_output.unsqueeze(1)
        # Assumes the hidden size is divisible by dist_heads — TODO confirm in config.
        embeds = embeds.view(embeds.shape[0], self.config.rep.dist_heads, -1)
        if self.config.rep.bound_embeds:
            embeds = embeds.tanh()
        return embeds

    def embedding_logsim_matrix(self, cls_ctxs, test_input_text):
        """Return a ``(len(test_input_text), len(cls_ctxs))`` matrix of ``-scale * distance``.

        The distance is Euclidean (squared when ``config.rep.square``) or ``1 - cosine`` when
        ``config.rep.cos``, minimised over the distance heads. Outside training, context
        embeddings come from ``self.cache_embeds`` when embedding caching is enabled.
        """
        if self.config.rep.cache_embeds and not self.config.rep.cross_attend and not self.training:
            # Each cached entry is (dist_heads, dim); stack to (num_ctx, dist_heads, dim) to
            # match the layout returned by compute_cls_embeddings.
            ctx_embeds = torch.stack([self.cache_embeds[ctx] for ctx in cls_ctxs])
        else:
            ctx_embeds = self.compute_cls_embeddings(cls_ctxs)
        main_embeds = self.compute_cls_embeddings(test_input_text)

        # Broadcast to (num_test, num_ctx, dist_heads, dim).
        if self.config.rep.cos:
            cos = (ctx_embeds[None] * main_embeds[:, None]).sum(-1) / (ctx_embeds[None].norm(2, -1) * main_embeds[:, None].norm(2, -1))
            dists = 1 - cos
        else:
            dists = (ctx_embeds[None] - main_embeds[:, None]).norm(2, -1)
            if self.config.rep.square:
                dists = dists ** 2

        dists = dists.min(-1).values  # get rid of the dists head dimension
        assert dists.min() >= 0, "Shouldn't have negative distances!"
        cls_logsims = -dists * self.scale
        return cls_logsims

    def crossattend_logsim_matrix(self, cls_ctxs, test_input_texts):
        """Score every (test, context) pair jointly with a sequence-classification classifier.

        Returns a ``(len(test_input_texts), len(cls_ctxs))`` matrix holding the log-probability
        of class 0 for ``ctx + sep + test``.
        """
        # Test-major ordering, so the flat scores reshape directly to (num_test, num_ctx).
        batch = [ctx + self.classifier_tok.sep_token + test for test in test_input_texts for ctx in cls_ctxs]
        batch_toks = self.classifier_tok(batch, return_tensors="pt", padding=True).to(self.config.device)
        batch_logsims = self.classifier(**batch_toks).logits.log_softmax(-1)[:, 0]
        logsim_matrix = batch_logsims.view(len(test_input_texts), len(cls_ctxs))
        return logsim_matrix

    def build_rep_cache_contexts(self):
        """Return the replacement-model prefix for each cache entry.

        The prefix is ``input + " "`` for GPT-style base models, whose cached inputs already
        include the label. Otherwise it is ``input + " " + label + " "``.
        """
        sep = " "
        if self._base_is_gpt():
            # Labels are already part of the inputs for autoregressive models.
            return [cin + sep for cin in self.cache_inputs]
        return [cin + sep + clab + sep for cin, clab in zip(self.cache_inputs, self.cache_labels)]

    def build_cls_cache_inputs(self, cache_inputs=None, cache_labels=None):
        """Return the classifier-side text for each cache entry (defaults to the current cache).

        The text is ``input + sep + label + sep`` using the classifier's sep token. For
        GPT-style base models it is the input with its last space-separated word removed,
        plus ``sep``; this assumes the label is that final word.
        """
        sep = self.classifier_tok.sep_token
        if cache_inputs is None:
            cache_inputs = self.cache_inputs
        if cache_labels is None:
            cache_labels = self.cache_labels

        if self._base_is_gpt():
            # The labels are included in the inputs for autoregressive models. Cut off the
            # label for the classifier.
            return [cin.rsplit(" ", 1)[0] + sep for cin in cache_inputs]
        return [cin + sep + clab + sep for cin, clab in zip(cache_inputs, cache_labels)]

    def build_rep_input_tokens(self, kwargs, idxs, generation=False):
        """Tokenize ``cache context + test input`` for the replacement model.

        Args:
            kwargs: forward/generate kwargs; reads ``input_ids`` and, unless ``generation``,
                ``labels``.
            idxs: one cache index per test input.
            generation: when True, no labels are produced.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (plus ``labels`` when not
            generating). For GPT-style base models the labels are left-padded with ``-100`` so
            their length matches the new ``input_ids``.
        """
        assert len(idxs) == len(kwargs["input_ids"]), "Need one cache idx for each test input"
        cache_contexts = self.build_rep_cache_contexts()
        selected_contexts = [cache_contexts[idx.item()] for idx in idxs]
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
        rep_texts = [ctx + inp for ctx, inp in zip(selected_contexts, test_inputs)]
        rep_input_tokens = self.replacement_tok(rep_texts, return_tensors="pt", padding=True).to(self.config.device)

        rep_kwargs = {
            "input_ids": rep_input_tokens["input_ids"],
            "attention_mask": rep_input_tokens["attention_mask"],
        }

        if not generation:
            labels = kwargs["labels"]
            rep_kwargs["labels"] = labels
            if self._base_is_gpt():
                # Add 'ignore' labels for the prepended cache inputs.
                n_prefix = rep_kwargs["input_ids"].shape[-1] - labels.shape[-1]
                pre = torch.full((labels.shape[0], n_prefix), -100, dtype=labels.dtype, device=labels.device)
                rep_kwargs["labels"] = torch.cat((pre, labels), dim=-1)

        return rep_kwargs

    def run_classifier(self, *inputs, **kwargs):
        """Score each test input (``kwargs["input_ids"]``) against every cached edit.

        Returns:
            ``(cls_sims, cls_idxs, log_sim_matrix)``: the best similarity and its cache index
            per test input, plus the full ``(num_test, num_cache)`` log-similarity matrix.
        """
        cache_inputs = self.build_cls_cache_inputs()
        test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)

        if self.config.rep.cross_attend:
            log_sim_matrix = self.crossattend_logsim_matrix(cache_inputs, test_inputs)
        else:
            log_sim_matrix = self.embedding_logsim_matrix(cache_inputs, test_inputs)

        sims = log_sim_matrix.exp()
        assert sims.max() <= 1, "Similarities shouldn't exceed 1!"

        cls_sims, cls_idxs = sims.max(-1)
        return cls_sims, cls_idxs, log_sim_matrix

    def generate(self, *args, **kwargs):
        """Generate with the base model, or via the counterfactual path for in-scope inputs.

        The keyword ``override="cntr"`` forces the counterfactual path; any other non-None
        value forces the base model. ``max_new_tokens`` defaults to 20. ``BertClassifier``
        models are invoked through ``forward`` instead of ``generate``. When
        ``config.rep.freeze_cntr`` is set, the base model handles the edit-augmented input.

        NOTE(review): routing uses ``cls_sims.item()``, so it assumes a batch of one input.
        """
        # Always strip the routing flag so it is never forwarded to the underlying model.
        override = kwargs.pop("override", None)
        if "max_new_tokens" not in kwargs:
            kwargs["max_new_tokens"] = 20

        def call(module):
            # Resolved at call time, so kwargs edits made below are picked up.
            fn = module.forward if isinstance(module, BertClassifier) else module.generate
            return fn(*args, **kwargs)

        if len(self.cache_inputs) > 0:
            cls_sims, cls_idxs, _ = self.run_classifier(*args, **kwargs)
            use_cntr = (override == "cntr") if override is not None else (cls_sims.item() > 0.5)
            if use_cntr:
                rep_input = self.build_rep_input_tokens(kwargs, cls_idxs, generation=True)
                kwargs["input_ids"] = rep_input["input_ids"]
                kwargs["attention_mask"] = rep_input["attention_mask"]
                return call(self.model if self.config.rep.freeze_cntr else self.replacement)

        return call(self.model)

    def forward(self, *inputs, return_logits_only=True, eps=torch.finfo(torch.float32).eps, pos_pairs=None, **kwargs):
        """Predict with the base model, overridden by the replacement model for in-scope inputs.

        With an empty cache this returns the base model's output as float. Otherwise the
        classifier picks the most similar cached edit per example, and the replacement model
        (the base model when ``freeze_cntr``) runs on ``edit context + input``.

        - Soft weighting returns ``log((1 - sim) * p_base + sim * p_rep + eps)``, using softmax
          for 3-D outputs and sigmoid otherwise.
        - Hard routing replaces base logits with replacement logits for rows whose
          similarity is > 0.5.

        Args:
            return_logits_only: if False, also return the classifier log-similarity matrix,
                the replacement logits for the gold contexts and a stats dict.
            eps: added inside the log of the soft mixture.
            pos_pairs: optional 2-column tensor; column 0 must equal ``arange(B)`` and
                column 1 gives each example's gold cache index.
        """
        # Soft mixing is used when the classifier is not directly supervised (it then receives
        # gradient through the mixture weights) or when explicitly requested.
        soft = (not self.config.rep.supervised) or self.config.rep.soft_weighting

        with torch.no_grad():
            if len(self.cache_inputs) == 0:
                return super().forward(*inputs, **kwargs).float()

            base_logits = super().forward(*inputs, **kwargs).float()
            if soft:
                base_probs = base_logits.softmax(-1) if base_logits.dim() == 3 else base_logits.sigmoid()
                del base_logits

        cls_sims, cls_idxs, cls_logits = self.run_classifier(*inputs, **kwargs)
        rep_cls_inputs = self.build_rep_input_tokens(kwargs, cls_idxs)
        if self.config.rep.freeze_cntr:
            rep_cls_logits = _logits(super().forward(**rep_cls_inputs))
        else:
            rep_cls_logits = _logits(self.replacement(**rep_cls_inputs))

        if pos_pairs is not None:
            assert (pos_pairs[:, 0] == torch.arange(pos_pairs.shape[0], device=pos_pairs.device)).all()
            gold_idxs = pos_pairs[:, 1]
            rep_gold_inputs = self.build_rep_input_tokens(kwargs, gold_idxs)
            # NOTE(review): unlike rep_cls_logits, this always uses the replacement model,
            # even with freeze_cntr — confirm intended.
            rep_gold_logits = _logits(self.replacement(**rep_gold_inputs))
        else:
            rep_gold_logits = rep_cls_logits

        cls_sims = cls_sims.view(-1, 1)  # For (binary) classification, predictions are (B x 1)
        if rep_cls_logits.dim() == 3:
            cls_sims = cls_sims.unsqueeze(-1)  # For generation/seq2seq, predictions are (B x S x V)

        stats = {
            'sims/mean': cls_sims.mean().item(),
            'sims/pos': (cls_sims >= 0.5).float().mean().item(),
            'sims/neg': (cls_sims < 0.5).float().mean().item(),
            'params/scale': self.scale.item() if self.scale is not None else 0.0,
        }

        if self._base_is_gpt():
            # Drop the positions of the prepended edit context so the logits align with labels.
            rep_cls_logits = rep_cls_logits[:, -kwargs["labels"].shape[-1]:, :]

        if soft:
            rep_weight = cls_sims
            rep_probs = rep_cls_logits.softmax(-1) if base_probs.dim() == 3 else rep_cls_logits.sigmoid()
            mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_probs + eps).log()
        else:
            rep_idxs = torch.where(cls_sims > 0.5)[0]
            mixture_logits = base_logits
            if rep_idxs.numel() > 0:
                mixture_logits[rep_idxs] = rep_cls_logits[rep_idxs]

        if return_logits_only:
            return mixture_logits
        return mixture_logits, cls_logits, rep_gold_logits, stats
if __name__ == '__main__':
    # Manual smoke test: downloads GPT-2 and requires CUDA.
    #
    # NOTE(review): this script appears to predate SERAC's current interface and fails at
    # construction as written. SERAC.__init__ reads config.rep.* and config.model.small_name /
    # config.model.class_name, none of which are set here. SERAC.edit also takes a batch dict
    # ({"input_ids", "labels"}) and returns (model, info) rather than accepting
    # (x, masks=..., labels=...). Update the config and calls before relying on it.
    import types

    logging.basicConfig(level=logging.INFO)  # otherwise the LOG.info output below is hidden

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()  # must exist before its attributes are assigned
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.edit_lr = 0.0001

    config.gtn = types.SimpleNamespace()
    config.gtn.n_hidden = 1
    config.gtn = config.gtn.__dict__

    gtn = SERAC(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(20).view(1, 20).cuda() + 1000
    orig_logits = gtn(x)
    edited = gtn.edit(x, masks=torch.ones_like(x), labels=x)
    post_logits = gtn(x)

    # Editing must return a new model and leave the original's predictions unchanged.
    assert torch.allclose(orig_logits, post_logits)

    target = config.model.inner_params[-1]
    orig_param = [p for (n, p) in gtn.model.named_parameters() if n == target][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == target][0]
    LOG.info("Max |orig - edited| for %s: %s", target, (orig_param - edited_param).abs().max())

    edited.eval()
    LOG.info("Losses (orig, edited, edited nll): %s %s %s",
             gtn(x, labels=x).loss, edited(x, labels=x).loss,
             edited.edit_loss_fn(edited(x).logits, x)["nll"])
    edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
    LOG.info("Losses (orig, edited, edited2): %s %s %s",
             gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)