Commit
•
71d8fad
1
Parent(s):
9ebba66
Create new file
Browse files
blip.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import warnings
|
9 |
+
warnings.filterwarnings("ignore")
|
10 |
+
|
11 |
+
from vit import VisionTransformer, interpolate_pos_embed
|
12 |
+
from med import BertConfig, BertModel, BertLMHeadModel
|
13 |
+
from transformers import BertTokenizer
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
import os
|
20 |
+
from urllib.parse import urlparse
|
21 |
+
from timm.models.hub import download_cached_file
|
22 |
+
|
23 |
+
class BLIP_Base(nn.Module):
|
24 |
+
def __init__(self,
|
25 |
+
med_config = 'configs/med_config.json',
|
26 |
+
image_size = 224,
|
27 |
+
vit = 'base',
|
28 |
+
vit_grad_ckpt = False,
|
29 |
+
vit_ckpt_layer = 0,
|
30 |
+
):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
34 |
+
image_size (int): input image size
|
35 |
+
vit (str): model size of vision transformer
|
36 |
+
"""
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
40 |
+
self.tokenizer = init_tokenizer()
|
41 |
+
med_config = BertConfig.from_json_file(med_config)
|
42 |
+
med_config.encoder_width = vision_width
|
43 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, image, caption, mode):
|
47 |
+
|
48 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
49 |
+
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
50 |
+
|
51 |
+
if mode=='image':
|
52 |
+
# return image features
|
53 |
+
image_embeds = self.visual_encoder(image)
|
54 |
+
return image_embeds
|
55 |
+
|
56 |
+
elif mode=='text':
|
57 |
+
# return text features
|
58 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
59 |
+
return_dict = True, mode = 'text')
|
60 |
+
return text_output.last_hidden_state
|
61 |
+
|
62 |
+
elif mode=='multimodal':
|
63 |
+
# return multimodel features
|
64 |
+
image_embeds = self.visual_encoder(image)
|
65 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
66 |
+
|
67 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
68 |
+
output = self.text_encoder(text.input_ids,
|
69 |
+
attention_mask = text.attention_mask,
|
70 |
+
encoder_hidden_states = image_embeds,
|
71 |
+
encoder_attention_mask = image_atts,
|
72 |
+
return_dict = True,
|
73 |
+
)
|
74 |
+
return output.last_hidden_state
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
class BLIP_Decoder(nn.Module):
|
79 |
+
def __init__(self,
|
80 |
+
med_config = 'configs/med_config.json',
|
81 |
+
image_size = 384,
|
82 |
+
vit = 'base',
|
83 |
+
vit_grad_ckpt = False,
|
84 |
+
vit_ckpt_layer = 0,
|
85 |
+
prompt = 'a picture of ',
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
90 |
+
image_size (int): input image size
|
91 |
+
vit (str): model size of vision transformer
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
96 |
+
self.tokenizer = init_tokenizer()
|
97 |
+
med_config = BertConfig.from_json_file(med_config)
|
98 |
+
med_config.encoder_width = vision_width
|
99 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
100 |
+
|
101 |
+
self.prompt = prompt
|
102 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
103 |
+
|
104 |
+
|
105 |
+
def forward(self, image, caption):
|
106 |
+
|
107 |
+
image_embeds = self.visual_encoder(image)
|
108 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
109 |
+
|
110 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
111 |
+
|
112 |
+
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
113 |
+
|
114 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
115 |
+
decoder_targets[:,:self.prompt_length] = -100
|
116 |
+
|
117 |
+
decoder_output = self.text_decoder(text.input_ids,
|
118 |
+
attention_mask = text.attention_mask,
|
119 |
+
encoder_hidden_states = image_embeds,
|
120 |
+
encoder_attention_mask = image_atts,
|
121 |
+
labels = decoder_targets,
|
122 |
+
return_dict = True,
|
123 |
+
)
|
124 |
+
loss_lm = decoder_output.loss
|
125 |
+
|
126 |
+
return loss_lm
|
127 |
+
|
128 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
129 |
+
image_embeds = self.visual_encoder(image)
|
130 |
+
|
131 |
+
if not sample:
|
132 |
+
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
133 |
+
|
134 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
135 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
136 |
+
|
137 |
+
prompt = [self.prompt] * image.size(0)
|
138 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
139 |
+
input_ids[:,0] = self.tokenizer.bos_token_id
|
140 |
+
input_ids = input_ids[:, :-1]
|
141 |
+
|
142 |
+
if sample:
|
143 |
+
#nucleus sampling
|
144 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
145 |
+
max_length=max_length,
|
146 |
+
min_length=min_length,
|
147 |
+
do_sample=True,
|
148 |
+
top_p=top_p,
|
149 |
+
num_return_sequences=1,
|
150 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
151 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
152 |
+
repetition_penalty=1.1,
|
153 |
+
**model_kwargs)
|
154 |
+
else:
|
155 |
+
#beam search
|
156 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
157 |
+
max_length=max_length,
|
158 |
+
min_length=min_length,
|
159 |
+
num_beams=num_beams,
|
160 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
161 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
162 |
+
repetition_penalty=repetition_penalty,
|
163 |
+
**model_kwargs)
|
164 |
+
|
165 |
+
captions = []
|
166 |
+
for output in outputs:
|
167 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
168 |
+
captions.append(caption[len(self.prompt):])
|
169 |
+
return captions
|
170 |
+
|
171 |
+
|
172 |
+
def blip_decoder(pretrained='',**kwargs):
|
173 |
+
model = BLIP_Decoder(**kwargs)
|
174 |
+
if pretrained:
|
175 |
+
model,msg = load_checkpoint(model,pretrained)
|
176 |
+
assert(len(msg.missing_keys)==0)
|
177 |
+
return model
|
178 |
+
|
179 |
+
def blip_feature_extractor(pretrained='',**kwargs):
|
180 |
+
model = BLIP_Base(**kwargs)
|
181 |
+
if pretrained:
|
182 |
+
model,msg = load_checkpoint(model,pretrained)
|
183 |
+
assert(len(msg.missing_keys)==0)
|
184 |
+
return model
|
185 |
+
|
186 |
+
def init_tokenizer():
|
187 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
188 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
189 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
190 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
191 |
+
return tokenizer
|
192 |
+
|
193 |
+
|
194 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
195 |
+
|
196 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
197 |
+
if vit=='base':
|
198 |
+
vision_width = 768
|
199 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
200 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
201 |
+
drop_path_rate=0 or drop_path_rate
|
202 |
+
)
|
203 |
+
elif vit=='large':
|
204 |
+
vision_width = 1024
|
205 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
206 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
207 |
+
drop_path_rate=0.1 or drop_path_rate
|
208 |
+
)
|
209 |
+
return visual_encoder, vision_width
|
210 |
+
|
211 |
+
def is_url(url_or_filename):
|
212 |
+
parsed = urlparse(url_or_filename)
|
213 |
+
return parsed.scheme in ("http", "https")
|
214 |
+
|
215 |
+
def load_checkpoint(model,url_or_filename):
|
216 |
+
if is_url(url_or_filename):
|
217 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
218 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
219 |
+
elif os.path.isfile(url_or_filename):
|
220 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
221 |
+
else:
|
222 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
223 |
+
|
224 |
+
state_dict = checkpoint['model']
|
225 |
+
|
226 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
227 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
228 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
229 |
+
model.visual_encoder_m)
|
230 |
+
for key in model.state_dict().keys():
|
231 |
+
if key in state_dict.keys():
|
232 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
233 |
+
del state_dict[key]
|
234 |
+
|
235 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
236 |
+
print('load checkpoint from %s'%url_or_filename)
|
237 |
+
return model,msg
|