Final test
- app_caption.py +2 -9
- app_vqa.py +8 -16
- prismer/model/modules/vit.py +11 -42
app_caption.py
CHANGED
@@ -31,20 +31,13 @@ def create_demo():
     inputs = [image, model_name]
     outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]

-    # paths = sorted(pathlib.Path('prismer/images').glob('*'))
-    # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
-    # gr.Examples(examples=examples,
-    #             inputs=inputs,
-    #             outputs=outputs,
-    #             fn=model.run_caption,
-    #             cache_examples=os.getenv('SYSTEM') == 'spaces')
-
     paths = sorted(pathlib.Path('prismer/images').glob('*'))
     examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
     gr.Examples(examples=examples,
                 inputs=inputs,
                 outputs=outputs,
-                fn=model.run_caption)
+                fn=model.run_caption,
+                cache_examples=os.getenv('SYSTEM') == 'spaces')

     run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
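Net effect: the example gallery is active rather than commented out, and example outputs are pre-computed only on the Spaces runtime, where the SYSTEM environment variable is set to 'spaces' (locally the check is falsy, so nothing is cached). A minimal, self-contained sketch of the same pattern follows; the predict function, the component set, and the images/ directory are placeholders for illustration, not code from this Space.

import os
import pathlib

import gradio as gr


def predict(image_path: str, model_name: str) -> str:
    # Stand-in for model.run_caption: return a caption for the given image.
    return f'caption from {model_name} for {image_path}'


with gr.Blocks() as demo:
    image = gr.Image(type='filepath', label='Input image')
    model_name = gr.Dropdown(['Prismer-Base'], value='Prismer-Base', label='Model')
    caption = gr.Textbox(label='Caption')
    run_button = gr.Button('Run')

    inputs = [image, model_name]
    outputs = [caption]

    paths = sorted(pathlib.Path('images').glob('*'))
    examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
    gr.Examples(examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=predict,
                # On Hugging Face Spaces SYSTEM == 'spaces', so examples are cached there only.
                cache_examples=os.getenv('SYSTEM') == 'spaces')

    run_button.click(fn=predict, inputs=inputs, outputs=outputs)

demo.launch()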
app_vqa.py
CHANGED
@@ -31,26 +31,18 @@ def create_demo():
     inputs = [image, model_name, question]
     outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]

-    # paths = sorted(pathlib.Path('prismer/images').glob('*'))
-    # ex_questions = ['What is the man on the right doing?',
-    #                 'What is this person playing?',
-    #                 'How many cows in this image?',
-    #                 'What is the type of animal in this image?',
-    #                 'What toy is it?']
-    #
-    # examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
-    # gr.Examples(examples=examples,
-    #             inputs=inputs,
-    #             outputs=outputs,
-    #             fn=model.run_vqa,
-    #             cache_examples=os.getenv('SYSTEM') == 'spaces')
-
     paths = sorted(pathlib.Path('prismer/images').glob('*'))
+    ex_questions = ['What is the man on the left doing?',
+                    'What is this person doing?',
+                    'How many cows in this image?',
+                    'What is the type of animal in this image?',
+                    'What toy is it?']
+    examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
     gr.Examples(examples=examples,
                 inputs=inputs,
                 outputs=outputs,
-                fn=model.run_vqa)
+                fn=model.run_vqa,
+                cache_examples=os.getenv('SYSTEM') == 'spaces')

     run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
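The example rows pair the i-th question in ex_questions with the i-th file from the sorted prismer/images directory, so the directory must not contain more files than there are questions, otherwise the list comprehension raises IndexError while the demo is being built. A small sketch of a more defensive pairing (an alternative for illustration, not what the Space does):

import pathlib

ex_questions = ['What is the man on the left doing?',
                'What is this person doing?',
                'How many cows in this image?',
                'What is the type of animal in this image?',
                'What toy is it?']

paths = sorted(pathlib.Path('prismer/images').glob('*'))

# zip() stops at the shorter sequence, so a count mismatch silently drops the
# extra images instead of raising IndexError at example-build time.
examples = [[path.as_posix(), 'Prismer-Base', question]
            for path, question in zip(paths, ex_questions)]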
prismer/model/modules/vit.py
CHANGED
@@ -173,45 +173,17 @@ class VisionTransformer(nn.Module):


 def load_encoder(name: str, experts: dict, image_resolution: int):
-    # modify keys (we only need Vision Transformer)
-    for key in list(state_dict.keys()):
-        if not key.startswith('visual'):
-            del state_dict[key]
-
-    for key in list(state_dict.keys()):
-        new_key = key.replace('visual.', '')
-        if 'proj' in new_key and 'transformer' not in new_key:
-            del state_dict[key]
-        elif 'conv1' in new_key:
-            new_key_ = new_key.replace('conv1', 'conv1.rgb')
-            state_dict[new_key_] = state_dict.pop(key)
-        elif 'positional_embedding' in new_key:
-            state_dict[new_key] = state_dict.pop(key)[1:]
-        elif 'transformer.resblocks' in new_key:
-            new_key_ = re.sub(".mlp", ".0.mlp", new_key)
-            new_key_ = re.sub(".attn", ".0.attn", new_key_)
-            new_key_ = re.sub(".ln", ".0.ln", new_key_)
-            state_dict[new_key_] = state_dict.pop(key)
-        else:
-            state_dict[new_key] = state_dict.pop(key)
-
-    # load pre-trained weights
-    vision_width = state_dict["conv1.rgb.weight"].shape[0]
-    vision_patch_size = state_dict["conv1.rgb.weight"].shape[-1]
-    vision_layers = len([k for k in state_dict.keys() if k.endswith(".attn.in_proj_weight")])
-    vision_heads = vision_width // 64
+    if name == 'ViT-B/16':
+        vision_width = 768
+        vision_patch_size = 16
+        vision_layers = 12
+        vision_heads = 12
+
+    elif name == 'ViT-L/14' or name == 'ViT-L/14@336px':
+        vision_width = 1024
+        vision_patch_size = 14
+        vision_layers = 24
+        vision_heads = 16

     ViT = VisionTransformer(input_resolution=image_resolution,
                             patch_size=vision_patch_size,
@@ -219,9 +191,6 @@ def load_encoder(name: str, experts: dict, image_resolution: int):
                             layers=vision_layers,
                             heads=vision_heads,
                             experts=experts)
-
-    state_dict['positional_embedding'] = interpolate_pos_embed(state_dict['positional_embedding'], len(ViT.positional_embedding))
-    ViT.load_state_dict(state_dict, strict=False)
     return ViT
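Net effect: load_encoder no longer derives the encoder configuration from a pretrained CLIP state_dict, and it no longer remaps its keys, interpolates the positional embedding, or calls load_state_dict. It now picks hard-coded hyper-parameters from the model name and returns a freshly constructed VisionTransformer, so any pretrained-weight loading has to happen outside this function. The hard-coded head counts match what the removed heuristic would have computed (vision_heads = vision_width // 64, i.e. 768 // 64 = 12 and 1024 // 64 = 16). One caveat of the if/elif form: an unrecognised name leaves vision_width and the other variables undefined, so the VisionTransformer call fails with a NameError rather than a clear message. A table-driven sketch of the same selection with an explicit error is below; VIT_CONFIGS and vit_config are illustrative names, not part of the repo.

# Illustrative only: equivalent of the new if/elif branches as a lookup table.
# The numbers come from the diff; VIT_CONFIGS and vit_config are not repo names.
VIT_CONFIGS = {
    'ViT-B/16': dict(width=768, patch_size=16, layers=12, heads=12),
    'ViT-L/14': dict(width=1024, patch_size=14, layers=24, heads=16),
    'ViT-L/14@336px': dict(width=1024, patch_size=14, layers=24, heads=16),
}


def vit_config(name: str) -> dict:
    try:
        return VIT_CONFIGS[name]
    except KeyError:
        raise ValueError(f'unsupported vision encoder: {name!r}') from None

For example, vit_config('ViT-B/16')['heads'] returns 12, and an unknown name raises ValueError up front instead of failing later with NameError.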