# ReactXT / model / blip2_opt.py
# SyrWin
# init
# 95f97c5
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import torch
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel
from ogb.utils import smiles2graph
from torch_geometric.loader.dataloader import Collater
from torch_geometric.data import Data
import numpy as np
from lavis.models.blip2_models.blip2 import (
# Blip2Base,
disabled_train,
)
from model.blip2 import Blip2Base
from model.help_funcs import get_not_allowed_tokens_ids
from transformers import AutoTokenizer
from transformers import OPTForCausalLM, OPTConfig
# from opendelta import LoraModel
# from opendelta.delta_models.lora import LoraConfig
# from opendelta.delta_configs
# Hugging Face identifiers of the Galactica checkpoints, from smallest to
# largest. NOTE(review): this list is not referenced anywhere in the visible
# part of this file -- confirm whether other modules import it.
opt_model_list = [
"facebook/galactica-125m",
"facebook/galactica-1.3b",
"facebook/galactica-6.7b",
"facebook/galactica-30b",
]
def mask_by_len(input, lens, fill_value=0):
    '''
    Overwrite the leading ``lens[i]`` entries of every row ``i`` with
    ``fill_value``. The tensor is modified in place and also returned.

    input: shape = [N, D]
    lens: shape = [N]
    '''
    # Column indices 0..D-1, created on the same device as the data.
    column_ids = torch.arange(input.shape[1], device=input.device)
    # True wherever the column index is below that row's length.
    leading = column_ids.reshape(1, -1) < lens.reshape(-1, 1)
    input[leading] = fill_value
    return input
def smiles2data(smiles):
    """
    Turn a SMILES string into a ``torch_geometric`` ``Data`` object.

    The string is parsed with ``ogb.utils.smiles2graph``; its ``node_feat``,
    ``edge_index`` and ``edge_feat`` numpy arrays are wrapped as tensors and
    stored as ``x``, ``edge_index`` and ``edge_attr`` respectively.
    """
    mol_graph = smiles2graph(smiles)
    return Data(
        x=torch.from_numpy(mol_graph['node_feat']),
        edge_index=torch.from_numpy(mol_graph['edge_index']),
        edge_attr=torch.from_numpy(mol_graph['edge_feat']),
    )
import re
# Sentinel string inserted around the characters of bracketed sequences.
# The f-string placeholders are the literal integer 1, so the value is
# "SPL1T-TH1S-Pl3A5E".
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
# Matches "[START_X]...[END_X]" for X in DNA, SMILES, I_SMILES, AMINO.
# Groups: 1 = start token, 2 = tag name X, 3 = enclosed sequence (non-greedy),
# 4 = end token; the \2 backreference forces the end tag to match the start
# tag. DOTALL is not set, so the enclosed sequence cannot span a newline.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
def _insert_split_marker(m: re.Match):
    """
    Rebuild one matched ``[START_*]...[END_*]`` span with ``SPLIT_MARKER``
    placed before every character of the enclosed sequence and once more
    before the end token.

    Parameters
    ----------
    m : re.Match
        Match with four groups: start token, tag name, enclosed sequence,
        end token. The tag name is ignored.

    Returns
    ----------
    str - the span with the split marker added
    """
    start_token, _, sequence, end_token = m.groups()
    # Prefix each character (newlines included) with the marker.
    marked_sequence = "".join(f"{SPLIT_MARKER}{char}" for char in sequence)
    return f"{start_token}{marked_sequence}{SPLIT_MARKER}{end_token}"
def escape_custom_split_sequence(text):
    """
    Insert the split marker into every ``[START_*]...[END_*]`` span of
    ``text`` that ``CUSTOM_SEQ_RE`` matches; all other text is left as is.

    Parameters
    ----------
    text : str
        Input text to process

    Returns
    ----------
    str - the text with the split marker added inside each matched span
    """
    escaped_text = re.sub(CUSTOM_SEQ_RE, _insert_split_marker, text)
    return escaped_text
def smiles_handler(text, mol_ph):
    """
    Collect the bracketed sequences in ``text`` and append a placeholder
    after each one.

    Parameters
    ----------
    text : str
        Input text, possibly containing ``[START_*]...[END_*]`` spans
        recognised by ``CUSTOM_SEQ_RE``.
    mol_ph : str
        Placeholder text inserted verbatim right after every matched span.

    Returns
    ----------
    (str, list[str]) - the text with placeholders appended and split markers
    inserted, and the enclosed sequences in order of appearance
    """
    smiles_list = [match.group(3) for match in CUSTOM_SEQ_RE.finditer(text)]
    # Use a callable replacement so the placeholder is inserted literally.
    # In a template string, a backslash in mol_ph would be parsed as an escape
    # and a leading digit would extend the preceding group reference.
    # group(0) equals start token + sequence + end token, i.e. the whole span.
    text = CUSTOM_SEQ_RE.sub(lambda match: f"{match.group(0)}{mol_ph}", text)
    text = escape_custom_split_sequence(text)
    return text, smiles_list
class Blip2OPT(Blip2Base):
    """
    Graph-conditioned causal language model built from a graph encoder, a
    Q-Former and an OPT-architecture decoder (Galactica checkpoints by
    default).

    Graphs are encoded by ``graph_encoder``, passed through the Q-Former with
    ``num_query_token`` learned queries per graph, projected by ``opt_proj``
    to the language model's hidden size and written into the token-embedding
    sequence at the positions flagged by ``is_mol_token``.
    """

    def __init__(
        self,
        bert_name,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        tune_qformer=False,
        num_query_token=32,
        cross_attention_freq=2,
        llm_tune='freeze',
        peft_dir='',
        opt_model="facebook/galactica-1.3b",
        prompt="",
        args=None,
    ):
        """
        Args:
            bert_name: forwarded to ``init_Qformer``.
            gin_num_layers, gin_hidden_dim, gin_drop_ratio: forwarded to
                ``init_graph_encoder``.
            tune_gnn (bool): when False the graph encoder is frozen and kept
                in eval mode.
            tune_qformer (bool): when False the Q-Former and the query tokens
                are frozen and the Q-Former is kept in eval mode.
            num_query_token (int): forwarded to ``init_Qformer``.
            cross_attention_freq (int): forwarded to ``init_Qformer``.
            llm_tune (str): ``'freeze'``, ``'lora'`` or ``'full'``.
            peft_dir (str): with ``llm_tune='lora'``, a non-empty value loads
                an existing adapter from this location instead of creating one.
            opt_model (str): model name or path given to the Hugging Face
                ``from_pretrained`` loaders.
            prompt (str): stored on ``self.prompt``; not otherwise used in
                this class.
            args: required namespace-like object. Attributes read here:
                every ``optconfig_*`` entry (forwarded to ``OPTConfig`` with
                the prefix removed), ``mode``, ``rxn_batch_size``,
                ``generate_restrict_tokens``, ``train_restrict_tokens`` and,
                for LoRA, ``peft_config``, ``lora_r``, ``lora_alpha``,
                ``lora_dropout``.

        Raises:
            ValueError: if ``args`` is None.
            NotImplementedError: if ``llm_tune`` is not a supported mode.
        """
        super().__init__()
        if args is None:
            # The default exists only for signature compatibility; every
            # code path below dereferences args.
            raise ValueError("Blip2OPT requires `args`; it cannot be None")
        self.args = args

        # ---- graph encoder ----
        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        self.tune_qformer = tune_qformer
        if tune_gnn:
            logging.info("tune graph encoder")
        else:
            self._freeze_submodule(self.graph_encoder)
            logging.info("freeze graph encoder")

        # ---- Q-Former ----
        self.num_query_token = num_query_token
        self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
        if tune_qformer:
            logging.info("tune qformer encoder")
        else:
            self._freeze_submodule(self.Qformer)
            self.query_tokens.requires_grad = False
            logging.info("freeze qformer encoder")

        ### remove the unused parameters
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        # ---- language model ----
        # Every args attribute named "optconfig_<key>" overrides <key> in the
        # OPT configuration.
        prefix = "optconfig_"
        opt_config_params = {
            key[len(prefix):]: value
            for key, value in vars(args).items()
            if key.startswith(prefix)
        }
        config = OPTConfig.from_pretrained(opt_model, **opt_config_params)

        ## initialize opt model
        self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model, use_fast=False, padding_side='right')
        self.opt_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        self.opt_tokenizer.add_tokens('<mol>')  # molecule placeholder
        self.mol_token = '<mol>'
        self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]

        self.collater = Collater([], [])

        if opt_model == 'facebook/galactica-125m':
            # The smallest checkpoint is loaded without a torch_dtype override.
            self.opt_model = OPTForCausalLM.from_pretrained(opt_model, config=config)
        else:
            half_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            self.opt_model = OPTForCausalLM.from_pretrained(opt_model, torch_dtype=half_dtype, config=config)
        # NOTE (original author): this will cause bug when full fine-tuning the opt model
        self.opt_model.resize_token_embeddings(len(self.opt_tokenizer))

        self.llm_tune = llm_tune
        if llm_tune == 'lora':
            if peft_dir:
                self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
            else:
                if args.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(args.peft_config))
                else:
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
                self.peft_config = peft_config
                self.opt_model = get_peft_model(self.opt_model, peft_config)
            self.opt_model.print_trainable_parameters()
        elif llm_tune == 'freeze':
            for param in self.opt_model.parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            pass
        else:
            raise NotImplementedError(f"unsupported llm_tune mode: {llm_tune!r}")

        ## fixme: this is different from the original BLIP2
        if args.mode == 'pretrain_eval':
            # A list of token ids: all ids of "[START_SMILES]\n" are passed
            # to generate() as eos_token_id.
            self.eos_token_id = self.opt_tokenizer(
                "[START_SMILES]\n", add_special_tokens=False
            ).input_ids
        else:
            # A single id: the first token of "\n".
            self.eos_token_id = self.opt_tokenizer(
                "\n", add_special_tokens=False
            ).input_ids[0]

        self.opt_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
        )

        ## fixme: no prompt yet
        self.prompt = prompt
        self.rxn_batch_size = args.rxn_batch_size
        self.generate_restrict_tokens = args.generate_restrict_tokens
        self.train_restrict_tokens = args.train_restrict_tokens
        if self.generate_restrict_tokens or self.train_restrict_tokens:
            self.bad_words_ids = get_not_allowed_tokens_ids(opt_model)
        # prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors="pt")
        # self.prompt_length = prompt_tokens.attention_mask.sum(1)

    @staticmethod
    def _freeze_submodule(module):
        """Disable gradients for ``module`` and pin it in eval mode."""
        for param in module.parameters():
            param.requires_grad = False
        module.eval()
        # Replace train() so later .train() calls on a parent cannot switch
        # this module back to training mode.
        module.train = disabled_train

    def _encode_mol_tokens(self, graphs):
        """
        Run graphs through the graph encoder, Q-Former and ``opt_proj``.

        Returns the projected Q-Former outputs; per the original author's
        note the shape is graph_num x num_query_token x D.
        """
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        if not self.tune_gnn:
            # Frozen encoder: cut the graph so no gradient flows into it.
            graph_embeds = graph_embeds.detach()
        graph_embeds = self.ln_graph(graph_embeds, graph_masks)
        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
            return_dict=True,
        )
        return self.opt_proj(query_output.last_hidden_state)

    def _embed_with_mol_tokens(self, tokens, mol_tokens=None):
        """
        Embed ``tokens.input_ids`` with the language model's input embedding
        and, when ``mol_tokens`` is given, overwrite the positions flagged by
        ``tokens.is_mol_token`` with the flattened molecule embeddings.
        """
        embeds = self.opt_model.get_input_embeddings()(tokens.input_ids)
        if mol_tokens is not None:
            # Cast to the embedding dtype: indexed assignment needs matching
            # dtypes, and the projection may run in a different precision
            # than the language model.
            embeds[tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(dtype=embeds.dtype)
        return embeds

    def opt_forward_v2(
        self,
        inputs_embeds,
        attention_mask,
        labels,
        bad_word_ids=None,
    ):
        """
        Compute the next-token cross-entropy loss with selected vocabulary
        entries suppressed.

        Args:
            inputs_embeds: input embeddings for the language model.
            attention_mask: attention mask for the language model.
            labels: target token ids; -100 entries are ignored by the loss.
            bad_word_ids: optional nested list of token ids whose logits are
                set to -100 before the loss is computed.

        Returns:
            The scalar loss tensor.
        """
        # Labels are deliberately not forwarded: the model's built-in loss
        # would be computed on unmasked logits and then discarded.
        output = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        logits = output.logits
        labels = labels.to(logits.device)
        if bad_word_ids:
            banned_ids = torch.tensor(bad_word_ids, device=logits.device, dtype=torch.long).reshape(-1)
            # A large negative (finite) logit makes these tokens contribute
            # practically nothing to the softmax.
            logits[:, :, banned_ids] = -100
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten using the logits' own vocabulary dimension.
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits, shift_labels.view(-1))
        return loss

    def forward_action(self, batch, use_gragh=True):
        """
        Training loss for a batch of ``(rxn_ids, graphs, text_tokens)``.

        Padding positions, molecule-placeholder positions and positions with
        ``token_type_ids == 0`` are excluded from the loss. ``rxn_ids`` is
        not used here.

        Args:
            batch: tuple ``(rxn_ids, graphs, text_tokens)``.
            use_gragh (bool): when False the graphs are ignored and the plain
                token embeddings are used. (The parameter name is misspelled
                but kept for caller compatibility.)

        Returns:
            dict with the single key ``"loss"``.
        """
        _rxn_ids, graphs, text_tokens = batch
        mol_tokens = self._encode_mol_tokens(graphs) if use_gragh else None

        input_ids = text_tokens.input_ids
        targets = input_ids.masked_fill(input_ids == self.opt_tokenizer.pad_token_id, -100)
        targets = targets.masked_fill(text_tokens.is_mol_token, -100)
        targets = targets.masked_fill(text_tokens.token_type_ids == 0, -100)

        inputs_embeds = self._embed_with_mol_tokens(text_tokens, mol_tokens)

        if self.train_restrict_tokens:
            loss = self.opt_forward_v2(
                inputs_embeds=inputs_embeds,
                attention_mask=text_tokens.attention_mask,
                labels=targets,
                bad_word_ids=self.bad_words_ids,
            )
        else:
            outputs = self.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=text_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss
        return {"loss": loss}

    def forward_abstract(self, batch, use_gragh=True):
        """
        Training loss for a batch of ``(graphs, text_tokens)``.

        Only padding and molecule-placeholder positions are excluded from the
        loss; unlike ``forward_action`` there is no ``token_type_ids`` mask
        and no token restriction.

        Args:
            batch: tuple ``(graphs, text_tokens)``.
            use_gragh (bool): when False the graphs are ignored and the plain
                token embeddings are used. (The parameter name is misspelled
                but kept for caller compatibility.)

        Returns:
            dict with the single key ``"loss"``.
        """
        graphs, text_tokens = batch
        mol_tokens = self._encode_mol_tokens(graphs) if use_gragh else None

        input_ids = text_tokens.input_ids
        targets = input_ids.masked_fill(input_ids == self.opt_tokenizer.pad_token_id, -100)
        targets = targets.masked_fill(text_tokens.is_mol_token, -100)

        inputs_embeds = self._embed_with_mol_tokens(text_tokens, mol_tokens)

        outputs = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
            labels=targets,
        )
        return {"loss": outputs.loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
        use_graph=True,
    ):
        """
        Generate text for a batch of prompts, optionally conditioned on graphs.

        Args:
            samples (dict): must contain ``'prompt_tokens'`` (an object with
                ``input_ids``, ``attention_mask`` and ``is_mol_token``) and,
                when ``use_graph`` is True, ``'graphs'``.
            do_sample, top_p, temperature, num_beams, max_length, min_length,
            repetition_penalty, length_penalty: forwarded unchanged to the
                language model's ``generate``.
            num_captions (int): forwarded as ``num_return_sequences``.
            use_graph (bool): when False the prompt's token embeddings are
                used as they are, without molecule embeddings.
        Returns:
            list of str: the decoded outputs with special tokens removed and
            surrounding whitespace stripped.
        """
        prompt_tokens = samples['prompt_tokens']
        # prompt_lens = samples['prompt_lens']
        # with self.maybe_autocast():
        # Graphs are only read when they are actually used.
        mol_tokens = self._encode_mol_tokens(samples['graphs']) if use_graph else None
        prompt_embeds = self._embed_with_mol_tokens(prompt_tokens, mol_tokens)

        extra_params = {}
        if self.generate_restrict_tokens:
            extra_params['bad_words_ids'] = self.bad_words_ids

        outputs = self.opt_model.generate(
            inputs_embeds=prompt_embeds,
            attention_mask=prompt_tokens.attention_mask,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            # pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
            # use_cache=False,
            **extra_params
        )
        output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [text.strip() for text in output_text]