Spaces:
Sleeping
Sleeping
Reset
Browse files
- app.py +13 -7
- app_caption.py +3 -13
- prismer_model.py +8 -32
app.py
CHANGED
@@ -11,25 +11,31 @@ import gradio as gr
|
|
11 |
if os.getenv('SYSTEM') == 'spaces':
|
12 |
with open('patch') as f:
|
13 |
subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
|
14 |
-
shutil.copytree('prismer/helpers/images',
|
15 |
-
'prismer/images',
|
16 |
-
dirs_exist_ok=True)
|
17 |
|
18 |
from app_caption import create_demo as create_demo_caption
|
19 |
from prismer_model import build_deformable_conv, download_models
|
20 |
|
|
|
21 |
download_models()
|
22 |
build_deformable_conv()
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
|
27 |
-
|
|
|
28 |
|
29 |
with gr.Blocks(css='style.css') as demo:
|
30 |
-
gr.Markdown(
|
31 |
with gr.Tabs():
|
32 |
-
with gr.TabItem('
|
33 |
create_demo_caption()
|
34 |
|
35 |
demo.queue(api_open=False).launch()
|
|
|
11 |
if os.getenv('SYSTEM') == 'spaces':
|
12 |
with open('patch') as f:
|
13 |
subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
|
14 |
+
shutil.copytree('prismer/helpers/images', 'prismer/images', dirs_exist_ok=True)
|
|
|
|
|
15 |
|
16 |
from app_caption import create_demo as create_demo_caption
|
17 |
from prismer_model import build_deformable_conv, download_models
|
18 |
|
19 |
+
# Prepare model checkpoints
|
20 |
download_models()
|
21 |
build_deformable_conv()
|
22 |
|
23 |
+
|
24 |
+
# Demo file here
|
25 |
+
description = """
|
26 |
+
# Prismer
|
27 |
+
The official demo for **Prismer: A Vision-Language Model with An Ensemble of Experts**.
|
28 |
+
Please refer to our [project page](https://shikun.io/projects/prismer) or [github](https://github.com/NVlabs/prismer) for more details.
|
29 |
+
"""
|
30 |
|
31 |
if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
|
32 |
+
description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
|
33 |
+
|
34 |
|
35 |
with gr.Blocks(css='style.css') as demo:
|
36 |
+
gr.Markdown(description)
|
37 |
with gr.Tabs():
|
38 |
+
with gr.TabItem('Zero-shot Image Captioning'):
|
39 |
create_demo_caption()
|
40 |
|
41 |
demo.queue(api_open=False).launch()
|
app_caption.py
CHANGED
@@ -15,10 +15,8 @@ def create_demo():
|
|
15 |
|
16 |
with gr.Row():
|
17 |
with gr.Column():
|
18 |
-
image = gr.Image(label='Input', type='filepath')
|
19 |
-
model_name = gr.Dropdown(label='Model',
|
20 |
-
choices=['prismer_base'],
|
21 |
-
value='prismer_base')
|
22 |
run_button = gr.Button('Run')
|
23 |
with gr.Column(scale=1.5):
|
24 |
caption = gr.Text(label='Caption')
|
@@ -32,15 +30,7 @@ def create_demo():
|
|
32 |
ocr = gr.Image(label='OCR Detection')
|
33 |
|
34 |
inputs = [image, model_name]
|
35 |
-
outputs = [
|
36 |
-
caption,
|
37 |
-
depth,
|
38 |
-
edge,
|
39 |
-
normals,
|
40 |
-
segmentation,
|
41 |
-
object_detection,
|
42 |
-
ocr,
|
43 |
-
]
|
44 |
|
45 |
paths = sorted(pathlib.Path('prismer/images').glob('*'))
|
46 |
examples = [[path.as_posix(), 'prismer_base'] for path in paths]
|
|
|
15 |
|
16 |
with gr.Row():
|
17 |
with gr.Column():
|
18 |
+
image = gr.Image(label='Input Image', type='filepath')
|
19 |
+
model_name = gr.Dropdown(label='Model Size', choices=['prismer_base'], value='prismer_base')
|
|
|
|
|
20 |
run_button = gr.Button('Run')
|
21 |
with gr.Column(scale=1.5):
|
22 |
caption = gr.Text(label='Caption')
|
|
|
30 |
ocr = gr.Image(label='OCR Detection')
|
31 |
|
32 |
inputs = [image, model_name]
|
33 |
+
outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
paths = sorted(pathlib.Path('prismer/images').glob('*'))
|
36 |
examples = [[path.as_posix(), 'prismer_base'] for path in paths]
|
prismer_model.py
CHANGED
@@ -20,32 +20,22 @@ from model.prismer_caption import PrismerCaption
|
|
20 |
|
21 |
def download_models() -> None:
|
22 |
if not pathlib.Path('prismer/experts/expert_weights/').exists():
|
23 |
-
subprocess.run(shlex.split(
|
24 |
-
|
25 |
-
cwd='prismer')
|
26 |
model_names = [
|
27 |
-
'vqa_prismer_base',
|
28 |
-
'vqa_prismer_large',
|
29 |
-
'vqa_prismerz_base',
|
30 |
-
'vqa_prismerz_large',
|
31 |
-
'caption_prismerz_base',
|
32 |
-
'caption_prismerz_large',
|
33 |
'caption_prismer_base',
|
34 |
'caption_prismer_large',
|
35 |
]
|
36 |
for model_name in model_names:
|
37 |
if pathlib.Path(f'prismer/logging/{model_name}').exists():
|
38 |
continue
|
39 |
-
subprocess.run(shlex.split(
|
40 |
-
f'python download_checkpoints.py --download_models={model_name}'),
|
41 |
-
cwd='prismer')
|
42 |
|
43 |
|
44 |
def build_deformable_conv() -> None:
|
45 |
-
subprocess.run(
|
46 |
-
shlex.split('sh make.sh'),
|
47 |
-
cwd=
|
48 |
-
'prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
|
49 |
|
50 |
|
51 |
def run_experts(image_path: str) -> tuple[str | None, ...]:
|
@@ -56,14 +46,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
56 |
out_path = image_dir / 'image.jpg'
|
57 |
cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
|
58 |
|
59 |
-
expert_names = [
|
60 |
-
'depth',
|
61 |
-
'edge',
|
62 |
-
'normal',
|
63 |
-
'objdet',
|
64 |
-
'ocrdet',
|
65 |
-
'segmentation',
|
66 |
-
]
|
67 |
for expert_name in expert_names:
|
68 |
env = os.environ.copy()
|
69 |
if 'PYTHONPATH' in env:
|
@@ -76,14 +59,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
76 |
env=env,
|
77 |
check=True)
|
78 |
|
79 |
-
keys = [
|
80 |
-
'depth',
|
81 |
-
'edge',
|
82 |
-
'normal',
|
83 |
-
'seg_coco',
|
84 |
-
'obj_detection',
|
85 |
-
'ocr_detection',
|
86 |
-
]
|
87 |
results = [
|
88 |
pathlib.Path('prismer/helpers/labels') / key /
|
89 |
'helpers/images/image.png' for key in keys
|
|
|
20 |
|
21 |
def download_models() -> None:
|
22 |
if not pathlib.Path('prismer/experts/expert_weights/').exists():
|
23 |
+
subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
|
24 |
+
|
|
|
25 |
model_names = [
|
26 |
+
# 'vqa_prismer_base',
|
27 |
+
# 'vqa_prismer_large',
|
|
|
|
|
|
|
|
|
28 |
'caption_prismer_base',
|
29 |
'caption_prismer_large',
|
30 |
]
|
31 |
for model_name in model_names:
|
32 |
if pathlib.Path(f'prismer/logging/{model_name}').exists():
|
33 |
continue
|
34 |
+
subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer')
|
|
|
|
|
35 |
|
36 |
|
37 |
def build_deformable_conv() -> None:
|
38 |
+
subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
|
46 |
out_path = image_dir / 'image.jpg'
|
47 |
cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
|
48 |
|
49 |
+
expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
for expert_name in expert_names:
|
51 |
env = os.environ.copy()
|
52 |
if 'PYTHONPATH' in env:
|
|
|
59 |
env=env,
|
60 |
check=True)
|
61 |
|
62 |
+
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
results = [
|
64 |
pathlib.Path('prismer/helpers/labels') / key /
|
65 |
'helpers/images/image.png' for key in keys
|