shikunl commited on
Commit
806eb00
1 Parent(s): 5a56ebb

Final test

Browse files
Files changed (3) hide show
  1. app_caption.py +2 -9
  2. app_vqa.py +8 -16
  3. prismer/model/modules/vit.py +11 -42
app_caption.py CHANGED
@@ -31,20 +31,13 @@ def create_demo():
31
  inputs = [image, model_name]
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
- # paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
- # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
36
- # gr.Examples(examples=examples,
37
- # inputs=inputs,
38
- # outputs=outputs,
39
- # fn=model.run_caption,
40
- # cache_examples=os.getenv('SYSTEM') == 'spaces')
41
-
42
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
43
  examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
44
  gr.Examples(examples=examples,
45
  inputs=inputs,
46
  outputs=outputs,
47
- fn=model.run_caption)
 
48
 
49
  run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
50
 
 
31
  inputs = [image, model_name]
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
 
 
 
 
 
 
 
 
34
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
  examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
36
  gr.Examples(examples=examples,
37
  inputs=inputs,
38
  outputs=outputs,
39
+ fn=model.run_caption,
40
+ cache_examples=os.getenv('SYSTEM') == 'spaces')
41
 
42
  run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
43
 
app_vqa.py CHANGED
@@ -31,26 +31,18 @@ def create_demo():
31
  inputs = [image, model_name, question]
32
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
- # paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
- # ex_questions = ['What is the man on the right doing?',
36
- # 'What is this person playing?',
37
- # 'How many cows in this image?',
38
- # 'What is the type of animal in this image?',
39
- # 'What toy is it?']
40
- #
41
- # examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
42
- # gr.Examples(examples=examples,
43
- # inputs=inputs,
44
- # outputs=outputs,
45
- # fn=model.run_vqa,
46
- # cache_examples=os.getenv('SYSTEM') == 'spaces')
47
-
48
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
49
- examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
 
 
 
 
 
50
  gr.Examples(examples=examples,
51
  inputs=inputs,
52
  outputs=outputs,
53
- fn=model.run_vqa)
 
54
 
55
  run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
56
 
 
31
  inputs = [image, model_name, question]
32
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
+ ex_questions = ['What is the man on the left doing?',
36
+ 'What is this person doing?',
37
+ 'How many cows in this image?',
38
+ 'What is the type of animal in this image?',
39
+ 'What toy is it?']
40
+ examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
41
  gr.Examples(examples=examples,
42
  inputs=inputs,
43
  outputs=outputs,
44
+ fn=model.run_vqa,
45
+ cache_examples=os.getenv('SYSTEM') == 'spaces')
46
 
47
  run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
48
 
prismer/model/modules/vit.py CHANGED
@@ -173,45 +173,17 @@ class VisionTransformer(nn.Module):
173
 
174
 
175
  def load_encoder(name: str, experts: dict, image_resolution: int):
176
- # load pre-trained model file
177
- if name in _MODELS:
178
- if name != 'ViT-H/14':
179
- model_path = _download(_MODELS[name], os.path.expanduser("cache/clip"))
180
- model = torch.jit.load(model_path, map_location="cpu")
181
- state_dict = model.state_dict()
182
- else:
183
- model_path = hf_hub_download(_MODELS[name], 'open_clip_pytorch_model.bin', revision=None, cache_dir="cache/clip")
184
- state_dict = torch.load(model_path, map_location="cpu")
185
- else:
186
- raise RuntimeError(f"Model {name} not found")
187
-
188
- # modify keys (we only need Vision Transformer)
189
- for key in list(state_dict.keys()):
190
- if not key.startswith('visual'):
191
- del state_dict[key]
192
-
193
- for key in list(state_dict.keys()):
194
- new_key = key.replace('visual.', '')
195
- if 'proj' in new_key and 'transformer' not in new_key:
196
- del state_dict[key]
197
- elif 'conv1' in new_key:
198
- new_key_ = new_key.replace('conv1', 'conv1.rgb')
199
- state_dict[new_key_] = state_dict.pop(key)
200
- elif 'positional_embedding' in new_key:
201
- state_dict[new_key] = state_dict.pop(key)[1:]
202
- elif 'transformer.resblocks' in new_key:
203
- new_key_ = re.sub(".mlp", ".0.mlp", new_key)
204
- new_key_ = re.sub(".attn", ".0.attn", new_key_)
205
- new_key_ = re.sub(".ln", ".0.ln", new_key_)
206
- state_dict[new_key_] = state_dict.pop(key)
207
- else:
208
- state_dict[new_key] = state_dict.pop(key)
209
-
210
- # load pre-trained weights
211
- vision_width = state_dict["conv1.rgb.weight"].shape[0]
212
- vision_patch_size = state_dict["conv1.rgb.weight"].shape[-1]
213
- vision_layers = len([k for k in state_dict.keys() if k.endswith(".attn.in_proj_weight")])
214
- vision_heads = vision_width // 64
215
 
216
  ViT = VisionTransformer(input_resolution=image_resolution,
217
  patch_size=vision_patch_size,
@@ -219,9 +191,6 @@ def load_encoder(name: str, experts: dict, image_resolution: int):
219
  layers=vision_layers,
220
  heads=vision_heads,
221
  experts=experts)
222
-
223
- state_dict['positional_embedding'] = interpolate_pos_embed(state_dict['positional_embedding'], len(ViT.positional_embedding))
224
- ViT.load_state_dict(state_dict, strict=False)
225
  return ViT
226
 
227
 
 
173
 
174
 
175
  def load_encoder(name: str, experts: dict, image_resolution: int):
176
+ if name == 'ViT-B/16':
177
+ vision_width = 768
178
+ vision_patch_size = 16
179
+ vision_layers = 12
180
+ vision_heads = 12
181
+
182
+ elif name == 'ViT-L/14' or name == 'ViT-L/14@336px':
183
+ vision_width = 1024
184
+ vision_patch_size = 14
185
+ vision_layers = 24
186
+ vision_heads = 16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  ViT = VisionTransformer(input_resolution=image_resolution,
189
  patch_size=vision_patch_size,
 
191
  layers=vision_layers,
192
  heads=vision_heads,
193
  experts=experts)
 
 
 
194
  return ViT
195
 
196