""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import logging | |
import torch | |
import torch.nn as nn | |
from torch.cuda.amp import autocast as autocast | |
from peft import get_peft_model, LoraConfig, TaskType, PeftModel | |
from lavis.models.blip2_models.blip2 import disabled_train | |
from model.blip2 import Blip2Base | |
# from model.smiles_t5_captioning | |
from lavis.models.blip2_models.modeling_t5 import T5ForConditionalGeneration | |
from transformers import AutoTokenizer, T5TokenizerFast | |
#, T5ForConditionalGeneration | |


class Blip2T5(Blip2Base):
    """
    BLIP-2-style model that pairs a Q-Former and a graph encoder with a T5 language model.
    Supported model types:
        - pretrained: pretrained model with vit-g
        - pretrain_vitL: pretrained model with vit-large
        - coco: finetuned model on coco
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2", "pretrain")
    """

    def __init__(
        self,
        bert_name,
        gin_num_layers,
        gin_hidden_dim,
        gin_drop_ratio,
        tune_gnn=False,
        num_query_token=32,
        cross_attention_freq=2,
        llm_tune='freeze',
        peft_dir='',
        opt_model="facebook/galactica-1.3b",
        prompt="",
        args=None,
    ):
        super().__init__()
        self.args = args

        self.graph_encoder, self.ln_graph = self.init_graph_encoder(gin_num_layers, gin_hidden_dim, gin_drop_ratio)
        self.tune_gnn = tune_gnn
        if not tune_gnn:
            for name, param in self.graph_encoder.named_parameters():
                param.requires_grad = False
            self.graph_encoder = self.graph_encoder.eval()
            self.graph_encoder.train = disabled_train
            logging.info("freeze graph encoder")

        self.num_query_token = num_query_token
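        # Q-Former: learnable query tokens cross-attend to the graph encoder outputs
        # (cross-attention inserted every `cross_attention_freq` layers), compressing
        # each graph into `num_query_token` embeddings.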
        self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.graph_encoder.num_features, cross_attention_freq)
        ### remove the unused parameters
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        # assert opt_model == 'laituan245/molt5-large'
        ## initialize opt model
        # self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model)
        self.opt_tokenizer = T5TokenizerFast.from_pretrained(opt_model)
        self.opt_tokenizer.add_tokens('<mol>')  # molecule placeholder
        self.mol_token = '<mol>'
        self.opt_tokenizer.mol_token_id = self.opt_tokenizer("<mol>", add_special_tokens=False).input_ids[0]

        self.opt_model = T5ForConditionalGeneration.from_pretrained(opt_model, torch_dtype=torch.float32)
        self.opt_model.resize_token_embeddings(len(self.opt_tokenizer))  ## this can cause a bug when fully fine-tuning the opt model
        self.llm_tune = llm_tune
        if llm_tune == 'lora':
            if peft_dir:
                self.opt_model = PeftModel.from_pretrained(self.opt_model, peft_dir, is_trainable=True)
            else:
                if self.args.peft_config:
                    peft_config = LoraConfig(**LoraConfig.from_json_file(self.args.peft_config))
                else:
                    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout)
                self.peft_config = peft_config
                self.opt_model = get_peft_model(self.opt_model, peft_config)
                self.opt_model.print_trainable_parameters()
        elif llm_tune == 'freeze':
            for name, param in self.opt_model.named_parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            pass
        else:
            raise NotImplementedError()

        ## fixme: this is different from the original BLIP2
        # self.eos_token_id = self.opt_tokenizer(
        #     "\n", add_special_tokens=False
        # ).input_ids[0]
        self.eos_token_id = self.opt_tokenizer(
            "</s>", add_special_tokens=False
        ).input_ids[0]
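        # Linear projection from the Q-Former hidden size to the T5 embedding size.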
        self.opt_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.opt_model.config.hidden_size
        )

    def forward(self, batch):
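        # Training forward pass: encode the graphs, compress them with the Q-Former,
        # project the query outputs into the T5 embedding space, splice them into the
        # prompt at the <mol> placeholder positions, and compute the LM loss.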
        graphs, prompt_tokens, text_tokens = batch
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        if not self.tune_gnn:
            graph_embeds = graph_embeds.detach()
        graph_embeds = self.ln_graph(graph_embeds, graph_masks)

        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
            return_dict=True,
        )
        mol_tokens = self.opt_proj(query_output.last_hidden_state)
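        # Pad tokens are replaced with -100 so the T5 loss ignores them.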
        targets = text_tokens.input_ids.masked_fill(
            text_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
        )

        with self.maybe_autocast(torch.float32):
            prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
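            # Replace the <mol> placeholder embeddings with the projected query outputs.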
            prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)

            outputs = self.opt_model(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                decoder_attention_mask=text_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss
        return {"loss": loss}

    def forward_action(self, batch, use_graph=True):
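        # Same as forward(), but the graph branch can be skipped (use_graph=False),
        # in which case the prompt is consumed as plain text.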
        rxn_ids, graphs, prompt_tokens, text_tokens = batch
        if use_graph:
            graph_embeds, graph_masks = self.graph_encoder(graphs)
            if not self.tune_gnn:
                graph_embeds = graph_embeds.detach()
            graph_embeds = self.ln_graph(graph_embeds, graph_masks)

            query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=graph_embeds,
                encoder_attention_mask=graph_masks,  # fixme: check whether this mask is correct
                return_dict=True,
            )
            mol_tokens = self.opt_proj(query_output.last_hidden_state)
        else:
            del graphs

        targets = text_tokens.input_ids.masked_fill(
            text_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
        )

        with self.maybe_autocast(torch.float32):
            prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
            if use_graph:
                prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)

            outputs = self.opt_model(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                decoder_attention_mask=text_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss
        return {"loss": loss}

    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - graphs: a batch of molecular graphs
                - prompt_tokens: tokenized prompts containing <mol> placeholder tokens
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of captions to be generated for each input.
        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        graphs = samples['graphs']
        prompt_tokens = samples['prompt_tokens']
        graph_embeds, graph_masks = self.graph_encoder(graphs)
        graph_embeds = self.ln_graph(graph_embeds)

        query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=graph_embeds,
            encoder_attention_mask=graph_masks,
            return_dict=True,
        )
        mol_tokens = self.opt_proj(query_output.last_hidden_state)

        with self.maybe_autocast(torch.float32):
            prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
            prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
            # prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
            # prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1)

            outputs = self.opt_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                # pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
                # use_cache=False,
            )
        output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        return output_text

    def generate_action(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
        use_graph=True,
    ):
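        # Same as generate(), but the graph branch can be skipped (use_graph=False),
        # in which case generation is conditioned on the text prompt alone.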
        graphs = samples['graphs']
        prompt_tokens = samples['prompt_tokens']
        if use_graph:
            graph_embeds, graph_masks = self.graph_encoder(graphs)
            graph_embeds = self.ln_graph(graph_embeds)

            query_tokens = self.query_tokens.expand(graph_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=graph_embeds,
                encoder_attention_mask=graph_masks,
                return_dict=True,
            )
            mol_tokens = self.opt_proj(query_output.last_hidden_state)

        with self.maybe_autocast(torch.float32):
            prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
            if use_graph:
                prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1).to(torch.float32)
            # prompt_embeds = self.opt_model.encoder.embed_tokens(prompt_tokens.input_ids)
            # prompt_embeds[prompt_tokens.is_mol_token] = mol_tokens.flatten(0, 1)

            outputs = self.opt_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_tokens.attention_mask,
                do_sample=do_sample,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                # pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
                # use_cache=False,
            )
        output_text = self.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        return output_text