Add VQA
- app_caption.py +1 -1
- app_vqa.py +6 -5
- prismer_model.py +28 -16
app_caption.py CHANGED
@@ -18,7 +18,7 @@ def create_demo():
                 model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
                 run_button = gr.Button('Run')
             with gr.Column(scale=1.5):
-                caption = gr.Text(label='
+                caption = gr.Text(label='Model Prediction')
                 with gr.Row():
                     depth = gr.Image(label='Depth')
                     edge = gr.Image(label='Edge')
app_vqa.py CHANGED
@@ -16,9 +16,10 @@ def create_demo():
             with gr.Column():
                 image = gr.Image(label='Input', type='filepath')
                 model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
+                question = gr.Text(label='Question')
                 run_button = gr.Button('Run')
             with gr.Column(scale=1.5):
-
+                answer = gr.Text(label='Model Prediction')
                 with gr.Row():
                     depth = gr.Image(label='Depth')
                     edge = gr.Image(label='Edge')
@@ -28,8 +29,8 @@ def create_demo():
                     object_detection = gr.Image(label='Object Detection')
                     ocr = gr.Image(label='OCR Detection')

-        inputs = [image, model_name]
-        outputs = [
+        inputs = [image, model_name, question]
+        outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]

         # paths = sorted(pathlib.Path('prismer/images').glob('*'))
         # examples = [[path.as_posix(), 'prismer_base'] for path in paths]
@@ -44,9 +45,9 @@ def create_demo():
         gr.Examples(examples=examples,
                     inputs=inputs,
                     outputs=outputs,
-                    fn=model.
+                    fn=model.run_vqa)

-        run_button.click(fn=model.
+        run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)


 if __name__ == '__main__':
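Taken together, the app_vqa.py changes thread a third input (the question) through to the model and route seven outputs back into the UI. A minimal sketch of how this demo might be launched end to end; the launcher module and the queue()/launch() calls are assumptions for illustration, not part of this commit, and only create_demo() comes from app_vqa.py:

# Hypothetical launcher; assumes create_demo() returns the gr.Blocks it builds.
from app_vqa import create_demo

if __name__ == '__main__':
    demo = create_demo()
    # Clicking 'Run' invokes model.run_vqa(image_path, model_name, question) and
    # fills [answer, depth, edge, normals, segmentation, object_detection, ocr].
    demo.queue().launch()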
prismer_model.py CHANGED
@@ -16,7 +16,9 @@ submodule_dir = repo_dir / 'prismer'
 sys.path.insert(0, submodule_dir.as_posix())

 from dataset import create_dataset, create_loader
+from dataset.utils import pre_question
 from model.prismer_caption import PrismerCaption
+from model.prismer_vqa import PrismerVQA


 def download_models() -> None:
@@ -73,6 +75,11 @@ class Model:
         if exp_name == self.exp_name:
             return

+        if exp_name == 'Prismer-Base':
+            model_name = 'prismer_base'
+        elif exp_name == 'Prismer-Large':
+            model_name = 'prismer_large'
+
         if self.mode == 'caption':
             config = {
                 'dataset': 'demo',
@@ -80,13 +87,12 @@ class Model:
                 'label_path': 'prismer/helpers/labels',
                 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
                 'image_resolution': 480,
-                'prismer_model':
+                'prismer_model': model_name,
                 'freeze': 'freeze_vision',
-                'prefix': '
+                'prefix': '',
             }
-
             model = PrismerCaption(config)
-            state_dict = torch.load(f'prismer/logging/
+            state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')

         elif self.mode == 'vqa':
             config = {
@@ -95,13 +101,12 @@ class Model:
                 'label_path': 'prismer/helpers/labels',
                 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
                 'image_resolution': 480,
-                'prismer_model':
+                'prismer_model': model_name,
                 'freeze': 'freeze_vision',
-                'prefix': 'A picture of',
             }

-            model =
-            state_dict = torch.load(f'prismer/logging/
+            model = PrismerVQA(config)
+            state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')

         model.load_state_dict(state_dict)
         model.eval()
@@ -131,14 +136,21 @@ class Model:
         return caption, *out_paths

     @torch.inference_mode()
-    def run_vqa_model(self, exp_name: str) -> str:
+    def run_vqa_model(self, exp_name: str, question: str) -> str:
         self.set_model(exp_name)
-        _, test_dataset = create_dataset('
+        _, test_dataset = create_dataset('caption', self.config)
         test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
         experts, _ = next(iter(test_loader))
-
-
-
-
-
-
+        question = pre_question(question)
+        answer = self.model(experts, question, train=False, inference='generate')
+        answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
+        answer = answer.to(experts['rgb'].device)[0]
+        answer = self.tokenizer.decode(answer, skip_special_tokens=True)
+        answer = answer.capitalize() + '.'
+        return answer
+
+    def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
+        out_paths = run_experts(image_path)
+        answer = self.run_vqa_model(model_name, question)
+        label_prettify(image_path, out_paths)
+        return answer, *out_paths
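For context, the new run_vqa entry point chains the expert preprocessing (run_experts) with run_vqa_model and the label prettifier. A minimal sketch of calling it directly, assuming Model takes the mode string in its constructor and using a hypothetical image path; neither detail is shown in this diff:

# Hypothetical direct use of the new VQA pipeline.
from prismer_model import Model, download_models

download_models()     # fetch the pretrained checkpoints referenced in set_model()
model = Model('vqa')  # assumed constructor argument selecting the PrismerVQA branch
answer, *expert_paths = model.run_vqa('prismer/images/example.jpg',  # hypothetical path
                                      'Prismer-Base',
                                      'What is the person doing?')
print(answer)  # run_vqa_model capitalizes the answer and appends a period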