shikunl commited on
Commit
0312353
1 Parent(s): b734d92

Test experts

Browse files
Files changed (2) hide show
  1. app_caption.py +11 -15
  2. prismer_model.py +30 -87
app_caption.py CHANGED
@@ -4,7 +4,6 @@ from __future__ import annotations
4
 
5
  import os
6
  import pathlib
7
-
8
  import gradio as gr
9
 
10
  from prismer_model import Model
@@ -16,9 +15,7 @@ def create_demo():
16
  with gr.Row():
17
  with gr.Column():
18
  image = gr.Image(label='Input', type='filepath')
19
- model_name = gr.Dropdown(label='Model',
20
- choices=['prismer_base'],
21
- value='prismer_base')
22
  run_button = gr.Button('Run')
23
  with gr.Column(scale=1.5):
24
  caption = gr.Text(label='Caption')
@@ -32,23 +29,22 @@ def create_demo():
32
  ocr = gr.Image(label='OCR Detection')
33
 
34
  inputs = [image, model_name]
35
- outputs = [
36
- caption,
37
- depth,
38
- edge,
39
- normals,
40
- segmentation,
41
- object_detection,
42
- ocr,
43
- ]
44
 
45
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
46
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
47
  gr.Examples(examples=examples,
48
  inputs=inputs,
49
  outputs=outputs,
50
- fn=model.run_caption,
51
- cache_examples=os.getenv('SYSTEM') == 'spaces')
52
 
53
  run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
54
 
 
4
 
5
  import os
6
  import pathlib
 
7
  import gradio as gr
8
 
9
  from prismer_model import Model
 
15
  with gr.Row():
16
  with gr.Column():
17
  image = gr.Image(label='Input', type='filepath')
18
+ model_name = gr.Dropdown(label='Model', choices=['prismer_base'], value='prismer_base')
 
 
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
  caption = gr.Text(label='Caption')
 
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
  inputs = [image, model_name]
32
+ outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
+
34
+ # paths = sorted(pathlib.Path('prismer/images').glob('*'))
35
+ # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
36
+ # gr.Examples(examples=examples,
37
+ # inputs=inputs,
38
+ # outputs=outputs,
39
+ # fn=model.run_caption,
40
+ # cache_examples=os.getenv('SYSTEM') == 'spaces')
41
 
42
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
43
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
44
  gr.Examples(examples=examples,
45
  inputs=inputs,
46
  outputs=outputs,
47
+ fn=model.run_caption)
 
48
 
49
  run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
50
 
prismer_model.py CHANGED
@@ -20,32 +20,22 @@ from model.prismer_caption import PrismerCaption
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
- subprocess.run(shlex.split(
24
- 'python download_checkpoints.py --download_experts=True'),
25
- cwd='prismer')
26
  model_names = [
27
- 'vqa_prismer_base',
28
- 'vqa_prismer_large',
29
- 'vqa_prismerz_base',
30
- 'vqa_prismerz_large',
31
- 'caption_prismerz_base',
32
- 'caption_prismerz_large',
33
  'caption_prismer_base',
34
  'caption_prismer_large',
35
  ]
36
  for model_name in model_names:
37
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
38
  continue
39
- subprocess.run(shlex.split(
40
- f'python download_checkpoints.py --download_models={model_name}'),
41
- cwd='prismer')
42
 
43
 
44
  def build_deformable_conv() -> None:
45
- subprocess.run(
46
- shlex.split('sh make.sh'),
47
- cwd=
48
- 'prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
49
 
50
 
51
  def run_experts(image_path: str) -> tuple[str | None, ...]:
@@ -56,40 +46,18 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
56
  out_path = image_dir / 'image.jpg'
57
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
58
 
59
- expert_names = [
60
- 'depth',
61
- 'edge',
62
- 'normal',
63
- 'objdet',
64
- 'ocrdet',
65
- 'segmentation',
66
- ]
67
  for expert_name in expert_names:
68
  env = os.environ.copy()
69
  if 'PYTHONPATH' in env:
70
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
71
  else:
72
  env['PYTHONPATH'] = submodule_dir.as_posix()
73
- subprocess.run(
74
- shlex.split(f'python experts/generate_{expert_name}.py'),
75
- cwd='prismer',
76
- env=env,
77
- check=True)
78
-
79
- keys = [
80
- 'depth',
81
- 'edge',
82
- 'normal',
83
- 'seg_coco',
84
- 'obj_detection',
85
- 'ocr_detection',
86
- ]
87
- results = [
88
- pathlib.Path('prismer/helpers/labels') / key /
89
- 'helpers/images/image.png' for key in keys
90
- ]
91
- return tuple(path.as_posix() if path.exists() else None
92
- for path in results)
93
 
94
 
95
  class Model:
@@ -102,67 +70,42 @@ class Model:
102
  def set_model(self, exp_name: str) -> None:
103
  if exp_name == self.exp_name:
104
  return
 
105
  config = {
106
- 'dataset':
107
- 'demo',
108
- 'data_path':
109
- 'prismer/helpers',
110
- 'label_path':
111
- 'prismer/helpers/labels',
112
- 'experts': [
113
- 'depth',
114
- 'normal',
115
- 'seg_coco',
116
- 'edge',
117
- 'obj_detection',
118
- 'ocr_detection',
119
- ],
120
- 'image_resolution':
121
- 480,
122
- 'prismer_model':
123
- 'prismer_base',
124
- 'freeze':
125
- 'freeze_vision',
126
- 'prefix':
127
- 'A picture of',
128
  }
 
129
  model = PrismerCaption(config)
130
- state_dict = torch.load(
131
- f'prismer/logging/caption_{exp_name}/pytorch_model.bin',
132
- map_location='cuda:0')
133
  model.load_state_dict(state_dict)
134
  model.eval()
135
- tokenizer = model.tokenizer
136
 
137
  self.config = config
138
  self.model = model
139
- self.tokenizer = tokenizer
140
  self.exp_name = exp_name
141
 
142
  @torch.inference_mode()
143
  def run_caption_model(self, exp_name: str) -> str:
144
  self.set_model(exp_name)
145
-
146
  _, test_dataset = create_dataset('caption', self.config)
147
- test_loader = create_loader(test_dataset,
148
- batch_size=1,
149
- num_workers=4,
150
- train=False)
151
  experts, _ = next(iter(test_loader))
152
- captions = self.model(experts,
153
- train=False,
154
- prefix=self.config['prefix'])
155
- captions = self.tokenizer(captions,
156
- max_length=30,
157
- padding='max_length',
158
- return_tensors='pt').input_ids
159
  caption = captions.to(experts['rgb'].device)[0]
160
  caption = self.tokenizer.decode(caption, skip_special_tokens=True)
161
  caption = caption.capitalize() + '.'
162
  return caption
163
 
164
- def run_caption(self, image_path: str,
165
- model_name: str) -> tuple[str | None, ...]:
166
  out_paths = run_experts(image_path)
167
- caption = self.run_caption_model(model_name)
168
- return caption, *out_paths
 
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
+ subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
24
+
 
25
  model_names = [
26
+ # 'vqa_prismer_base',
27
+ # 'vqa_prismer_large',
 
 
 
 
28
  'caption_prismer_base',
29
  'caption_prismer_large',
30
  ]
31
  for model_name in model_names:
32
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
33
  continue
34
+ subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer')
 
 
35
 
36
 
37
  def build_deformable_conv() -> None:
38
+ subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
 
 
 
39
 
40
 
41
  def run_experts(image_path: str) -> tuple[str | None, ...]:
 
46
  out_path = image_dir / 'image.jpg'
47
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
48
 
49
+ expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
 
 
 
 
 
 
 
50
  for expert_name in expert_names:
51
  env = os.environ.copy()
52
  if 'PYTHONPATH' in env:
53
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
54
  else:
55
  env['PYTHONPATH'] = submodule_dir.as_posix()
56
+ subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
57
+
58
+ keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
59
+ results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
60
+ return tuple(path.as_posix() if path.exists() else None for path in results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  class Model:
 
70
  def set_model(self, exp_name: str) -> None:
71
  if exp_name == self.exp_name:
72
  return
73
+
74
  config = {
75
+ 'dataset': 'demo',
76
+ 'data_path': 'prismer/helpers',
77
+ 'label_path': 'prismer/helpers/labels',
78
+ 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
79
+ 'image_resolution': 480,
80
+ 'prismer_model': 'prismer_base',
81
+ 'freeze': 'freeze_vision',
82
+ 'prefix': 'A picture of',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  }
84
+
85
  model = PrismerCaption(config)
86
+ state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
 
 
87
  model.load_state_dict(state_dict)
88
  model.eval()
 
89
 
90
  self.config = config
91
  self.model = model
92
+ self.tokenizer = model.tokenizer
93
  self.exp_name = exp_name
94
 
95
  @torch.inference_mode()
96
  def run_caption_model(self, exp_name: str) -> str:
97
  self.set_model(exp_name)
 
98
  _, test_dataset = create_dataset('caption', self.config)
99
+ test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
 
 
 
100
  experts, _ = next(iter(test_loader))
101
+ captions = self.model(experts, train=False, prefix=self.config['prefix'])
102
+ captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
 
 
 
 
 
103
  caption = captions.to(experts['rgb'].device)[0]
104
  caption = self.tokenizer.decode(caption, skip_special_tokens=True)
105
  caption = caption.capitalize() + '.'
106
  return caption
107
 
108
+ def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
 
109
  out_paths = run_experts(image_path)
110
+ # caption = self.run_caption_model(model_name)
111
+ return None, *out_paths