shikunl committed on
Commit
53b7b42
1 Parent(s): 19327c9
Files changed (3) hide show
  1. app_caption.py +1 -1
  2. app_vqa.py +6 -5
  3. prismer_model.py +28 -16
app_caption.py CHANGED
@@ -18,7 +18,7 @@ def create_demo():
18
  model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
- caption = gr.Text(label='Caption')
22
  with gr.Row():
23
  depth = gr.Image(label='Depth')
24
  edge = gr.Image(label='Edge')
 
18
  model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
+ caption = gr.Text(label='Model Prediction')
22
  with gr.Row():
23
  depth = gr.Image(label='Depth')
24
  edge = gr.Image(label='Edge')
app_vqa.py CHANGED
@@ -16,9 +16,10 @@ def create_demo():
16
  with gr.Column():
17
  image = gr.Image(label='Input', type='filepath')
18
  model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
 
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
- caption = gr.Text(label='Caption')
22
  with gr.Row():
23
  depth = gr.Image(label='Depth')
24
  edge = gr.Image(label='Edge')
@@ -28,8 +29,8 @@ def create_demo():
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
- inputs = [image, model_name]
32
- outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
  # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
@@ -44,9 +45,9 @@ def create_demo():
44
  gr.Examples(examples=examples,
45
  inputs=inputs,
46
  outputs=outputs,
47
- fn=model.run_caption)
48
 
49
- run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
50
 
51
 
52
  if __name__ == '__main__':
 
16
  with gr.Column():
17
  image = gr.Image(label='Input', type='filepath')
18
  model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
19
+ question = gr.Text(label='Question')
20
  run_button = gr.Button('Run')
21
  with gr.Column(scale=1.5):
22
+ answer = gr.Text(label='Model Prediction')
23
  with gr.Row():
24
  depth = gr.Image(label='Depth')
25
  edge = gr.Image(label='Edge')
 
29
  object_detection = gr.Image(label='Object Detection')
30
  ocr = gr.Image(label='OCR Detection')
31
 
32
+ inputs = [image, model_name, question]
33
+ outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
34
 
35
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
36
  # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
 
45
  gr.Examples(examples=examples,
46
  inputs=inputs,
47
  outputs=outputs,
48
+ fn=model.run_vqa_model)
49
 
50
+ run_button.click(fn=model.run_vqa_model, inputs=inputs, outputs=outputs)
51
 
52
 
53
  if __name__ == '__main__':
prismer_model.py CHANGED
@@ -16,7 +16,9 @@ submodule_dir = repo_dir / 'prismer'
16
  sys.path.insert(0, submodule_dir.as_posix())
17
 
18
  from dataset import create_dataset, create_loader
 
19
  from model.prismer_caption import PrismerCaption
 
20
 
21
 
22
  def download_models() -> None:
@@ -73,6 +75,11 @@ class Model:
73
  if exp_name == self.exp_name:
74
  return
75
 
 
 
 
 
 
76
  if self.mode == 'caption':
77
  config = {
78
  'dataset': 'demo',
@@ -80,13 +87,12 @@ class Model:
80
  'label_path': 'prismer/helpers/labels',
81
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
82
  'image_resolution': 480,
83
- 'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
84
  'freeze': 'freeze_vision',
85
- 'prefix': 'A picture of',
86
  }
87
-
88
  model = PrismerCaption(config)
89
- state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
90
 
91
  elif self.mode == 'vqa':
92
  config = {
@@ -95,13 +101,12 @@ class Model:
95
  'label_path': 'prismer/helpers/labels',
96
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
97
  'image_resolution': 480,
98
- 'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
99
  'freeze': 'freeze_vision',
100
- 'prefix': 'A picture of',
101
  }
102
 
103
- model = PrismerCaption(config)
104
- state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
105
 
106
  model.load_state_dict(state_dict)
107
  model.eval()
@@ -131,14 +136,21 @@ class Model:
131
  return caption, *out_paths
132
 
133
  @torch.inference_mode()
134
- def run_vqa_model(self, exp_name: str) -> str:
135
  self.set_model(exp_name)
136
- _, test_dataset = create_dataset('vqa', self.config)
137
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
138
  experts, _ = next(iter(test_loader))
139
- captions = self.model(experts, train=False, prefix=self.config['prefix'])
140
- captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
141
- caption = captions.to(experts['rgb'].device)[0]
142
- caption = self.tokenizer.decode(caption, skip_special_tokens=True)
143
- caption = caption.capitalize() + '.'
144
- return caption
 
 
 
 
 
 
 
 
16
  sys.path.insert(0, submodule_dir.as_posix())
17
 
18
  from dataset import create_dataset, create_loader
19
+ from dataset.utils import pre_question
20
  from model.prismer_caption import PrismerCaption
21
+ from model.prismer_vqa import PrismerVQA
22
 
23
 
24
  def download_models() -> None:
 
75
  if exp_name == self.exp_name:
76
  return
77
 
78
+ if self.exp_name == 'Prismer-Base':
79
+ model_name = 'prismer_base'
80
+ elif self.exp_name == 'Prismer-Large':
81
+ model_name = 'prismer_large'
82
+
83
  if self.mode == 'caption':
84
  config = {
85
  'dataset': 'demo',
 
87
  'label_path': 'prismer/helpers/labels',
88
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
89
  'image_resolution': 480,
90
+ 'prismer_model': model_name,
91
  'freeze': 'freeze_vision',
92
+ 'prefix': '',
93
  }
 
94
  model = PrismerCaption(config)
95
+ state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
96
 
97
  elif self.mode == 'vqa':
98
  config = {
 
101
  'label_path': 'prismer/helpers/labels',
102
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
103
  'image_resolution': 480,
104
+ 'prismer_model': model_name,
105
  'freeze': 'freeze_vision',
 
106
  }
107
 
108
+ model = PrismerVQA(config)
109
+ state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
110
 
111
  model.load_state_dict(state_dict)
112
  model.eval()
 
136
  return caption, *out_paths
137
 
138
  @torch.inference_mode()
139
+ def run_vqa_model(self, exp_name: str, question: str) -> str:
140
  self.set_model(exp_name)
141
+ _, test_dataset = create_dataset('caption', self.config)
142
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
143
  experts, _ = next(iter(test_loader))
144
+ question = pre_question(question)
145
+ answer = self.model(experts, question, train=False, inference='generate')
146
+ answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
147
+ answer = answer.to(experts['rgb'].device)[0]
148
+ answer = self.tokenizer.decode(answer, skip_special_tokens=True)
149
+ answer = answer.capitalize() + '.'
150
+ return answer
151
+
152
def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
    """Run the full VQA pipeline for one image and one question.

    The steps run in a fixed order: ``run_experts`` on the image, then
    ``run_vqa_model`` for the answer, then ``label_prettify`` on the
    expert outputs.

    Args:
        image_path: Path of the input image.
        model_name: Model name forwarded to ``run_vqa_model``.
        question: Question text forwarded to ``run_vqa_model``.

    Returns:
        A tuple whose first item is the answer string, followed by every
        entry returned by ``run_experts``.
    """
    expert_outputs = run_experts(image_path)
    predicted_answer = self.run_vqa_model(model_name, question)
    # Runs after inference; its return value is not used here.
    label_prettify(image_path, expert_outputs)
    return (predicted_answer, *expert_outputs)