""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import contextlib
import logging
import os

import torch
import torch.nn as nn
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer

from model.gin_model import GNN
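

# Base model and helper constructors for a BLIP-2-style Q-Former paired with a
# GIN graph encoder (adapted from Salesforce LAVIS for this repo).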
class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        # SciBERT vocabulary; a local 'bert_pretrained/' checkpoint can be
        # substituted here if the hub model cannot be downloaded.
        bert_name = 'allenai/scibert_scivocab_uncased'
        tokenizer = BertTokenizer.from_pretrained(bert_name)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer
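
    # Note: '[DEC]' serves as the decoder BOS token in BLIP-2-style models;
    # since it enlarges the vocabulary, any language head built on this
    # tokenizer needs its token embeddings resized to match.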

    def maybe_autocast(self, dtype=torch.float16):
        # On CPU, autocast is disabled; on GPU, autocast with the given
        # dtype (torch.float16 by default).
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()
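
    # Usage sketch (illustrative; `graph_encoder`/`graph_batch` stand in for
    # whatever the subclass holds). The same code path works on CPU, where
    # the context is a no-op:
    #     with self.maybe_autocast():
    #         feats = self.graph_encoder(graph_batch)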

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, graph_width, cross_attention_freq=2):
        assert model_name == 'scibert'
        print("loading SciBERT backbone for the Q-Former")
        # SciBERT backbone; a local 'bert_pretrained/' checkpoint can be
        # substituted here if the hub model cannot be downloaded.
        bert_name = 'allenai/scibert_scivocab_uncased'
        encoder_config = BertConfig.from_pretrained(bert_name)
        encoder_config.encoder_width = graph_width
        # insert a cross-attention layer every `cross_attention_freq` blocks
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(
            bert_name, config=encoder_config
        )
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
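
    # Note: as in BLIP-2, the learned query tokens are what attend to the
    # graph features through the inserted cross-attention layers. A typical
    # call (hypothetical parameter values):
    #     Qformer, query_tokens = Blip2Base.init_Qformer(
    #         'scibert', num_query_token=8, graph_width=300)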

    @classmethod
    def init_graph_encoder(cls, gin_num_layers, gin_hidden_dim, gin_drop_ratio):
        graph_encoder = GNN(
            num_layer=gin_num_layers,
            emb_dim=gin_hidden_dim,
            gnn_type='gin',
            drop_ratio=gin_drop_ratio,
            JK='last',
        )
        # Load the pretrained GIN weights; report any key mismatches instead
        # of failing hard.
        ckpt = torch.load('gin_pretrained/graphcl_80.pth', map_location=torch.device('cpu'))
        missing_keys, unexpected_keys = graph_encoder.load_state_dict(ckpt, strict=False)
        if len(missing_keys) or len(unexpected_keys):
            print(missing_keys)
            print(unexpected_keys)
        ln_graph = LayerNorm(graph_encoder.num_features)
        return graph_encoder, ln_graph
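
    # Note: 'gin_pretrained/graphcl_80.pth' is a repo-local checkpoint
    # (judging by the name, a GraphCL-pretrained GIN). LayerNorm here is the
    # fp16-aware subclass defined at the bottom of this file; the name is
    # resolved at call time, so the forward reference is safe.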

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)
        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s", url_or_filename)
        return msg
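
    # Usage sketch (the path is illustrative, not a file shipped with this
    # repo); the returned object reports any unmatched keys:
    #     msg = model.load_from_pretrained('checkpoints/blip2_stage1.pth')
    #     print(msg.missing_keys)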


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure the train/eval
    mode does not change anymore."""
    return self
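

# Usage sketch (mirrors the LAVIS freezing idiom): pin a submodule in eval
# mode so later .train()/.eval() calls on the parent cannot flip it back.
# `graph_encoder` is illustrative:
#     graph_encoder = graph_encoder.eval()
#     graph_encoder.train = disabled_train
#     for p in graph_encoder.parameters():
#         p.requires_grad = False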


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16. Note: the fp32 upcast is
    currently commented out, so this forward matches nn.LayerNorm's and the
    cast back to the original dtype is a no-op."""

    def forward(self, x: torch.Tensor, mask=None):
        orig_type = x.dtype
        # ret = super().forward(x.type(torch.float32))
        ret = super().forward(x)
        return ret.type(orig_type)
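

# Minimal end-to-end sketch (parameter values are illustrative, not this
# repo's defaults; assumes the SciBERT weights and the local
# 'gin_pretrained/graphcl_80.pth' checkpoint are available):
#
#     tokenizer = Blip2Base.init_tokenizer()
#     Qformer, query_tokens = Blip2Base.init_Qformer(
#         'scibert', num_query_token=8, graph_width=300)
#     graph_encoder, ln_graph = Blip2Base.init_graph_encoder(
#         gin_num_layers=5, gin_hidden_dim=300, gin_drop_ratio=0.0)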