Fix ocr

- app.py +3 -0
- app_vqa.py +54 -0
- label_prettify.py +5 -2
- prismer_model.py +19 -6
app.py
CHANGED
@@ -16,6 +16,7 @@ if os.getenv('SYSTEM') == 'spaces':
                     dirs_exist_ok=True)
 
 from app_caption import create_demo as create_demo_caption
+from app_vqa import create_demo as create_demo_vqa
 from prismer_model import build_deformable_conv, download_models
 
 
@@ -40,5 +41,7 @@ with gr.Blocks() as demo:
     with gr.Tabs():
         with gr.TabItem('Zero-shot Image Captioning'):
             create_demo_caption()
+        with gr.TabItem('Visual Question Answering'):
+            create_demo_vqa()
 
 demo.queue(api_open=False).launch()
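For context, the new tab is wired in through plain Gradio composition: each create_demo_* factory builds its components inside the surrounding gr.Blocks context. A minimal, self-contained sketch of that pattern, with stand-in factories instead of the real ones from app_caption.py and app_vqa.py:

import gradio as gr


def create_demo_caption():
    # Stand-in for app_caption.create_demo: builds its UI inside the caller's Blocks context.
    gr.Markdown('Captioning UI goes here')


def create_demo_vqa():
    # Stand-in for app_vqa.create_demo.
    gr.Markdown('VQA UI goes here')


with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem('Zero-shot Image Captioning'):
            create_demo_caption()
        with gr.TabItem('Visual Question Answering'):
            create_demo_vqa()

if __name__ == '__main__':
    demo.queue(api_open=False).launch()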
app_vqa.py
ADDED
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+import pathlib
+import gradio as gr
+
+from prismer_model import Model
+
+
+def create_demo():
+    model = Model()
+    model.mode = 'vqa'
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(label='Input', type='filepath')
+            model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
+            run_button = gr.Button('Run')
+        with gr.Column(scale=1.5):
+            caption = gr.Text(label='Caption')
+            with gr.Row():
+                depth = gr.Image(label='Depth')
+                edge = gr.Image(label='Edge')
+                normals = gr.Image(label='Normals')
+            with gr.Row():
+                segmentation = gr.Image(label='Segmentation')
+                object_detection = gr.Image(label='Object Detection')
+                ocr = gr.Image(label='OCR Detection')
+
+    inputs = [image, model_name]
+    outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
+
+    # paths = sorted(pathlib.Path('prismer/images').glob('*'))
+    # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
+    # gr.Examples(examples=examples,
+    #             inputs=inputs,
+    #             outputs=outputs,
+    #             fn=model.run_caption,
+    #             cache_examples=os.getenv('SYSTEM') == 'spaces')
+
+    paths = sorted(pathlib.Path('prismer/images').glob('*'))
+    examples = [[path.as_posix(), 'Prismer-Base'] for path in paths]
+    gr.Examples(examples=examples,
+                inputs=inputs,
+                outputs=outputs,
+                fn=model.run_caption)
+
+    run_button.click(fn=model.run_caption, inputs=inputs, outputs=outputs)
+
+
+if __name__ == '__main__':
+    demo = create_demo()
+    demo.queue().launch()
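Note the click-wiring contract this file depends on: the callable passed to run_button.click must return one value per component listed in outputs, in the same order. A toy sketch with a dummy function standing in for Model.run_caption (all names below are illustrative only):

import gradio as gr


def fake_run(image_path, model_name):
    # One return value per output component, in order:
    # the caption text first, then six image slots (None leaves a slot empty).
    return f'{model_name} caption for {image_path}', None, None, None, None, None, None


with gr.Blocks() as demo:
    image = gr.Image(label='Input', type='filepath')
    model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
    caption = gr.Text(label='Caption')
    expert_views = [gr.Image(label=name) for name in
                    ['Depth', 'Edge', 'Normals', 'Segmentation', 'Object Detection', 'OCR Detection']]
    run_button = gr.Button('Run')
    run_button.click(fn=fake_run, inputs=[image, model_name], outputs=[caption, *expert_views])

if __name__ == '__main__':
    demo.launch()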
label_prettify.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 import json
+import random
 import torch
 import matplotlib.pyplot as plt
 import matplotlib
@@ -65,7 +66,8 @@ def seg_prettify(rgb_path, file_name):
 
     for i in np.unique(seg_labels):
         obj_idx_all = np.where(seg_labels == i)
-
+        obj_idx = random.randint(0, len(obj_idx_all[0]))
+        x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
         obj_name = coco_label_map[int(i * 255)]
         obj_name = obj_name.split(',')[0]
         if islight(seg_map[int(y), int(x)]):
@@ -105,8 +107,9 @@ def ocr_detection_prettify(rgb_path, file_name):
 
         x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
         plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
-
         plt.axis('off')
+
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
         plt.close()
 
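The seg_prettify change places each label at a randomly sampled pixel of its mask instead of a fixed point. A small self-contained sketch of that indexing pattern (the toy array below is illustrative); note that random.randint is inclusive at both ends, so the upper bound needs to be len(idx[0]) - 1 (or use random.randrange) to stay in range:

import random

import numpy as np

# Toy per-pixel label map standing in for seg_labels.
seg_labels = np.array([[0, 0, 1],
                       [0, 1, 1],
                       [2, 2, 2]])

for i in np.unique(seg_labels):
    idx = np.where(seg_labels == i)          # (row indices, column indices) of this mask
    j = random.randint(0, len(idx[0]) - 1)   # inclusive upper bound, hence the -1
    x, y = idx[1][j], idx[0][j]              # x = column, y = row
    print(f'label {i}: place text at (x={x}, y={y})')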
prismer_model.py
CHANGED
@@ -24,10 +24,10 @@ def download_models() -> None:
     subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
 
     model_names = [
-
-
-
-
+        'vqa_prismer_base',
+        'vqa_prismer_large',
+        'pretrain_prismer_base',
+        'pretrain_prismer_large',
     ]
     for model_name in model_names:
         if pathlib.Path(f'prismer/logging/{model_name}').exists():
@@ -126,6 +126,19 @@
 
     def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
         out_paths = run_experts(image_path)
-
+        caption = self.run_caption_model(model_name)
         label_prettify(image_path, out_paths)
-        return
+        return caption, *out_paths
+
+    @torch.inference_mode()
+    def run_vqa_model(self, exp_name: str) -> str:
+        self.set_model(exp_name)
+        _, test_dataset = create_dataset('vqa', self.config)
+        test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
+        experts, _ = next(iter(test_loader))
+        captions = self.model(experts, train=False, prefix=self.config['prefix'])
+        captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
+        caption = captions.to(experts['rgb'].device)[0]
+        caption = self.tokenizer.decode(caption, skip_special_tokens=True)
+        caption = caption.capitalize() + '.'
+        return caption