shikunl commited on
Commit
19327c9
1 Parent(s): 6eaf487
Files changed (4) hide show
  1. app.py +3 -0
  2. app_vqa.py +54 -0
  3. label_prettify.py +5 -2
  4. prismer_model.py +19 -6
app.py CHANGED
@@ -16,6 +16,7 @@ if os.getenv('SYSTEM') == 'spaces':
16
  dirs_exist_ok=True)
17
 
18
  from app_caption import create_demo as create_demo_caption
 
19
  from prismer_model import build_deformable_conv, download_models
20
 
21
 
@@ -40,5 +41,7 @@ with gr.Blocks() as demo:
40
  with gr.Tabs():
41
  with gr.TabItem('Zero-shot Image Captioning'):
42
  create_demo_caption()
 
 
43
 
44
  demo.queue(api_open=False).launch()
 
16
  dirs_exist_ok=True)
17
 
18
  from app_caption import create_demo as create_demo_caption
19
+ from app_vqa import create_demo as create_demo_vqa
20
  from prismer_model import build_deformable_conv, download_models
21
 
22
 
 
41
  with gr.Tabs():
42
  with gr.TabItem('Zero-shot Image Captioning'):
43
  create_demo_caption()
44
+ with gr.TabItem('Visual Question Answering'):
45
+ create_demo_vqa()
46
 
47
  demo.queue(api_open=False).launch()
app_vqa.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import pathlib
7
+ import gradio as gr
8
+
9
+ from prismer_model import Model
10
+
11
+
12
+ def create_demo():
13
+ model = Model()
14
+ model.mode = 'vqa'
15
+ with gr.Row():
16
+ with gr.Column():
17
+ image = gr.Image(label='Input', type='filepath')
18
+ model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
19
+ run_button = gr.Button('Run')
20
+ with gr.Column(scale=1.5):
21
+ caption = gr.Text(label='Caption')
22
+ with gr.Row():
23
+ depth = gr.Image(label='Depth')
24
+ edge = gr.Image(label='Edge')
25
+ normals = gr.Image(label='Normals')
26
+ with gr.Row():
27
+ segmentation = gr.Image(label='Segmentation')
28
+ object_detection = gr.Image(label='Object Detection')
29
+ ocr = gr.Image(label='OCR Detection')
30
+
31
+ inputs = [image, model_name]
32
+ outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
+
34
+ # paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
+ # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
36
+ # gr.Examples(examples=examples,
37
+ # inputs=inputs,
38
+ # outputs=outputs,
39
+ # fn=model.run_caption,
40
+ # cache_examples=os.getenv('SYSTEM') == 'spaces')
41
+
42
+ paths = sorted(pathlib.Path('prismer/images').glob('*'))
43
+ examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
44
+ gr.Examples(examples=examples,
45
+ inputs=inputs,
46
+ outputs=outputs,
47
+ fn=model.run_caption)
48
+
49
+ run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
50
+
51
+
52
+ if __name__ == '__main__':
53
+ demo = create_demo()
54
+ demo.queue().launch()
label_prettify.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import json
 
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import matplotlib
@@ -65,7 +66,8 @@ def seg_prettify(rgb_path, file_name):
65
 
66
  for i in np.unique(seg_labels):
67
  obj_idx_all = np.where(seg_labels == i)
68
- x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
 
69
  obj_name = coco_label_map[int(i * 255)]
70
  obj_name = obj_name.split(',')[0]
71
  if islight(seg_map[int(y), int(x)]):
@@ -105,8 +107,9 @@ def ocr_detection_prettify(rgb_path, file_name):
105
 
106
  x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
107
  plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
108
-
109
  plt.axis('off')
 
 
110
  plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
111
  plt.close()
112
 
 
1
  import os
2
  import json
3
+ import random
4
  import torch
5
  import matplotlib.pyplot as plt
6
  import matplotlib
 
66
 
67
  for i in np.unique(seg_labels):
68
  obj_idx_all = np.where(seg_labels == i)
69
+ obj_idx = random.randint(0, len(obj_idx_all[0]))
70
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
71
  obj_name = coco_label_map[int(i * 255)]
72
  obj_name = obj_name.split(',')[0]
73
  if islight(seg_map[int(y), int(x)]):
 
107
 
108
  x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
109
  plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
 
110
  plt.axis('off')
111
+
112
+ os.makedirs(os.path.dirname(file_name), exist_ok=True)
113
  plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
114
  plt.close()
115
 
prismer_model.py CHANGED
@@ -24,10 +24,10 @@ def download_models() -> None:
24
  subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
25
 
26
  model_names = [
27
- # 'vqa_prismer_base',
28
- # 'vqa_prismer_large',
29
- 'caption_prismer_base',
30
- 'caption_prismer_large',
31
  ]
32
  for model_name in model_names:
33
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
@@ -126,6 +126,19 @@ class Model:
126
 
127
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
128
  out_paths = run_experts(image_path)
129
- # caption = self.run_caption_model(model_name)
130
  label_prettify(image_path, out_paths)
131
- return None, *out_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
25
 
26
  model_names = [
27
+ 'vqa_prismer_base',
28
+ 'vqa_prismer_large',
29
+ 'pretrain_prismer_base',
30
+ 'pretrain_prismer_large',
31
  ]
32
  for model_name in model_names:
33
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
 
126
 
127
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
128
  out_paths = run_experts(image_path)
129
+ caption = self.run_caption_model(model_name)
130
  label_prettify(image_path, out_paths)
131
+ return caption, *out_paths
132
+
133
+ @torch.inference_mode()
134
+ def run_vqa_model(self, exp_name: str) -> str:
135
+ self.set_model(exp_name)
136
+ _, test_dataset = create_dataset('vqa', self.config)
137
+ test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
138
+ experts, _ = next(iter(test_loader))
139
+ captions = self.model(experts, train=False, prefix=self.config['prefix'])
140
+ captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
141
+ caption = captions.to(experts['rgb'].device)[0]
142
+ caption = self.tokenizer.decode(caption, skip_special_tokens=True)
143
+ caption = caption.capitalize() + '.'
144
+ return caption