|
|
|
import torch |
|
# NOTE(review): seeds torch's global RNG as an import-time side effect, so
# merely importing this module resets random state for the whole process
# (including the random initialisation of modules built afterwards).
torch.manual_seed(1024)
|
|
|
import torch.nn as nn |
|
from transformers import PreTrainedModel |
|
|
|
from .configuration_hformer import HformerConfig |
|
from .qformer_src import BertConfig, BertLMHeadModel |
|
|
|
from transformers import BertTokenizerFast as BertTokenizer |
|
|
|
from .configuration_projector import ProjectorConfig |
|
from .modeling_projector import ProjectorModel |
|
from .fuse_modules import BiAttentionBlock |
|
import torch.nn.functional as F |
|
from transformers.activations import ACT2FN |
|
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """Named ``nn.LayerNorm`` subclass used by the Hformer modules.

    ``forward`` defers directly to ``nn.LayerNorm.forward`` and performs no
    dtype casting of its own, so input/output dtypes follow the usual
    ``nn.LayerNorm`` rules.
    """

    def forward(self, x: torch.Tensor):
        # Pure pass-through to the parent implementation.
        return super().forward(x)
|
|
|
|
|
|
|
|
|
class HformerModel(PreTrainedModel):
    """Map visual features into the LLM hidden space with two fused branches.

    * Q-Former branch: learnable query tokens cross-attend (via
      ``self.Qformer.bert``) to layer-normalised visual features; the query
      outputs are projected to ``llm_hidden_size`` by ``self.llm_proj``.
    * MLP branch: ``self.connector`` (a ``ProjectorModel``) projects the raw,
      un-normalised visual features to ``llm_hidden_size``.

    The MLP features are refined residually by ``self.fuse`` (a
    ``BiAttentionBlock`` that also receives the Q-Former features) and then by
    the residual bottleneck feed-forward ``self.ffn``.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    # ``forward`` does not implement checkpointing; transformers checks this
    # flag before allowing ``gradient_checkpointing_enable()``.
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build the Q-Former branch, the MLP branch and the fusion layers.

        Args:
            config: ``HformerConfig``. Fields read here: ``visual_hidden_size``,
                ``num_query_token``, ``bert`` (name/path used for the BERT
                config, weights and tokenizer), ``llm_hidden_size``,
                ``cross_attention_freq``, ``qformer_pth`` (optional path to a
                pretrained Q-Former checkpoint, or ``None``) and ``bias``
                (whether ``llm_proj`` has a bias term).

        NOTE: the order in which sub-modules are created is significant — it
        determines how the global RNG is consumed during initialisation.
        """
        super().__init__(config)
        self.gradient_checkpointing = False

        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # --- Q-Former: BERT extended with cross-attention into the vision features.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        # Interpreted by qformer_src (not visible here); presumably a
        # cross-attention layer every ``cross_attention_freq`` blocks.
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )

        # Disabled toggle: when True, strips the LM head, the word/position
        # embeddings and each layer's ``intermediate``/``output`` sub-modules.
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learnable queries of shape (1, num_query_token, hidden_size);
        # expanded over the batch dimension in ``forward``.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # NOTE(review): ``ln_llava`` is never used by ``forward``; it is kept so
        # parameter names stay compatible with existing checkpoints.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # Register the "[DEC]" BOS token and resize the Q-Former's token
        # embeddings to the tokenizer's vocabulary (the tokenizer is discarded).
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # SECURITY NOTE(review): ``torch.load`` unpickles the file, which can
            # execute arbitrary code — only load checkpoints from trusted sources.
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Load Qformer from {qformer_pth}')
            # strict=False: keys that do not match the modules registered so far
            # are ignored, and missing ones keep their fresh initialisation.
            # The connector/fuse/ffn modules below do not exist yet.
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')

        # --- MLP branch: visual_hidden_size -> llm_hidden_size.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # --- Fusion of the MLP branch with the Q-Former branch (both llm_hidden_size).
        d_model = config.llm_hidden_size
        dim_feedforward = 1024
        nhead = 8
        fusion_dropout = 0.0
        fusion_droppath = 0.1
        self.fuse = BiAttentionBlock(
            v_dim=d_model,
            l_dim=d_model,
            embed_dim=dim_feedforward,
            num_heads=nhead,
            dropout=fusion_dropout,
            drop_path=fusion_droppath,
        )

        # Bottleneck FFN (hidden -> hidden // 4 -> hidden), applied residually.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Mark the outputs of the main sub-modules as requiring gradients.

        Registers a forward hook on each sub-module that calls
        ``requires_grad_(True)`` on its output. Autograd then tracks everything
        downstream, even when those modules and their inputs are frozen.
        """
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                # Only floating-point tensors can require grad. Skip other
                # elements (None, integer tensors, ...) and tolerate tuples of
                # any length instead of indexing blindly.
                for item in output:
                    if isinstance(item, torch.Tensor) and item.is_floating_point():
                        item.requires_grad_(True)
            else:
                output.requires_grad_(True)

        # NOTE(review): ``forward`` calls ``self.Qformer.bert`` directly, so this
        # hook on the outer ``self.Qformer`` module never fires.
        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)
        self.fuse.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle ``gradient_checkpointing`` on ``ProjectorModel`` sub-modules.

        Other module types are left untouched. This model's own ``forward``
        does not checkpoint, and ``supports_gradient_checkpointing`` is False.
        """
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x_):
        """Project visual features into the LLM hidden space.

        Args:
            x_: visual features whose last dimension is ``visual_hidden_size``;
                the batch size is read from ``x_.shape[0]``. Presumably shaped
                (batch, num_visual_tokens, visual_hidden_size) — TODO confirm.

        Returns:
            ``mlp_feat + self.ffn(mlp_feat)``, where ``mlp_feat`` is the
            connector output plus ``self.fuse(connector_output, q_feat)``.
        """
        if self.gradient_checkpointing and self.training:
            print('Not supported gradient checkpointing')

        # Q-Former branch operates on layer-normalised features ...
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # ... while the MLP branch consumes the raw, un-normalised input.
        mlp_feat = self.connector(x_)

        # Residual fusion with the Q-Former features, then a residual FFN.
        mlp_feat = mlp_feat + self.fuse(mlp_feat, q_feat)
        return mlp_feat + self.ffn(mlp_feat)
|
|
|
|