Confused about bidirectional attention when implementing custom sampling loop
I'm trying to implement a custom sampling loop for GPT-JT because I need some features that model.generate doesn't support. However, I'm a bit confused about how the bidirectional attention mask is tracked. Can someone point me to the code where GPT-JT's bidirectional vs. causal masking is controlled?
In this answer, @juewang mentions that the causal attention mask for GPT-JT is set to 1 by default. However, loading GPT-JT with transformers.AutoModelForCausalLM.from_pretrained just loads a normal GPT-J model, and the attention bias for GPT-J defaults to causal attention, as far as I can tell from here.
Could someone explain what I'm missing? I'm confused about how GPT-JT can implement custom attention masking when there doesn't seem to be any GPT-JT-specific code in Hugging Face (it just relies on the GPT-J implementation).
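For concreteness, this is all I do to load it (checkpoint name is the public GPT-JT release, I believe), and the class check prints the plain GPT-J class:

from transformers import AutoModelForCausalLM

# Loading GPT-JT resolves to the stock GPT-J implementation in transformers.
model = AutoModelForCausalLM.from_pretrained("togethercomputer/GPT-JT-6B-v1")
print(type(model).__name__)  # -> GPTJForCausalLM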
Thanks!
I was confused because I didn't realize that the attention_mask is actually a PyTorch registered buffer, i.e., part of the weights checkpoint; it's not controlled in code. The mask is in model.transformer.h[i].attn.bias.data[:].
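To double-check, you can inspect that buffer right after loading (checkpoint name assumed to be the public GPT-JT release; the exact shape may differ between transformers versions):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("togethercomputer/GPT-JT-6B-v1")

# The mask buffer sits on every attention block and is restored from the checkpoint.
bias = model.transformer.h[0].attn.bias
print(bias.shape)          # roughly (1, 1, max_positions, max_positions)
print(bias[0, 0, :8, :8])  # lower-triangular for a purely causal model; GPT-JT's
                           # checkpoint can override this to allow (partially)
                           # bidirectional attention over the prompt.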
My simple sampling loop looks like this, for reference:
import torch

def gptjt_sample(model, tokenizer, prompt_text, max_length=100, eos_token_id=None, do_sample=False):
    dev = next(model.parameters()).device
    input_ids = tokenizer(prompt_text, return_tensors='pt').input_ids.to(dev)
    past_key_values = None
    output_ids = input_ids
    for _ in range(max_length):
        # Once the KV cache is populated, only the newest token needs to be fed in.
        possibly_only_last_token = output_ids[:, -1:] if past_key_values is not None else output_ids
        outputs = model(possibly_only_last_token, use_cache=True, past_key_values=past_key_values,
                        output_hidden_states=True)
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]
        if do_sample:
            # Sample from the softmax distribution over the vocabulary.
            next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), num_samples=1)
        else:
            # Greedy decoding.
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        output_ids = torch.cat([output_ids, next_token], dim=-1)
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
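And roughly how I call it (model name and prompt are just placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/GPT-JT-6B-v1")
model = AutoModelForCausalLM.from_pretrained("togethercomputer/GPT-JT-6B-v1").eval()

text = gptjt_sample(model, tokenizer, "Q: What is the capital of France?\nA:",
                    max_length=32, eos_token_id=tokenizer.eos_token_id)
print(text)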
Yeah, you are right: attention_mask is a registered buffer, so it gets overwritten when the checkpoint is loaded.
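A minimal sketch of what that means in plain PyTorch (toy module, not the actual GPT-J code):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer (persistent by default) puts the tensor into the
        # state_dict, so it is saved with and restored from checkpoints.
        self.register_buffer("bias", torch.tril(torch.ones(4, 4)))

m = Toy()
print(m.bias)  # default: lower-triangular (causal)

# A checkpoint can store a different mask under the same buffer name;
# loading it replaces the default.
m.load_state_dict({"bias": torch.ones(4, 4)})
print(m.bias)  # now all ones (bidirectional)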