shikunl commited on
Commit
64fb58a
1 Parent(s): c83f375

Update structure

Browse files
Files changed (4) hide show
  1. app.py +31 -4
  2. examples/1.jpeg +0 -0
  3. gradio_caption.py +32 -0
  4. gradio_vqa.py +33 -0
app.py CHANGED
@@ -1,7 +1,34 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
 
4
+ from gradio_caption import create_demo as create_caption
5
+ from gradio_vqa import create_demo as create_vqa
6
 
7
+
8
+ css = """
9
+ #img-display-input {
10
+ height: auto;
11
+ max-height: 40vh;
12
+ }
13
+ #img-display-output {
14
+ max-height: 40vh;
15
+ }
16
+ """
17
+
18
+
19
+ description = """
20
+ # Prismer
21
+ The official demo for **Prismer: A Vision-Language Model with An Ensemble of Experts**.
22
+ Please refer to our [project page](https://shikun.io/projects/prismer) or [github](https://github.com/NVlabs/prismer) for more details.
23
+ """
24
+
25
+ with gr.Blocks(css=css) as demo:
26
+ gr.Markdown(description)
27
+ with gr.Tab("Zero-shot Image Captioning"):
28
+ create_caption()
29
+ with gr.Tab("Visual Question Answering"):
30
+ create_vqa()
31
+
32
+
33
+ if __name__ == '__main__':
34
+ demo.queue().launch()
examples/1.jpeg ADDED
gradio_caption.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import tempfile
4
+
5
+
6
+ def predict_depth(model, image):
7
+ depth = model.infer_pil(image)
8
+ return depth
9
+
10
+
11
+ def create_demo():
12
+ with gr.Row():
13
+ with gr.Column(scale=1):
14
+ model_type = gr.Dropdown(["Prismer-Base", "Prismer-Large"], label="Model Size", value="Prismer-Base")
15
+ rgb = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
16
+ submit = gr.Button("Submit")
17
+ with gr.Column(scale=2):
18
+ pred = gr.Textbox(label="Model Prediction")
19
+ with gr.Row():
20
+ depth = gr.Image(label="Depth", elem_id='img-display-output')
21
+ edge = gr.Image(label="Edge", elem_id='img-display-output')
22
+ normals = gr.Image(label="Normals", elem_id='img-display-output')
23
+ with gr.Row():
24
+ seg = gr.Image(label="Segmentation", elem_id='img-display-output')
25
+ obj_det = gr.Image(label="Object Detection", elem_id='img-display-output')
26
+ ocr_det = gr.Image(label="OCR Detection", elem_id='img-display-output')
27
+
28
+ def on_submit(im, model_type):
29
+ return pred, depth, edge, normals, seg, obj_det, ocr_det
30
+
31
+ submit.click(on_submit, inputs=[rgb, model_type], outputs=[pred, depth, edge, normals, seg, obj_det, ocr_det])
32
+ examples = gr.Examples(examples=["examples/1.jpeg"], inputs=[rgb])
gradio_vqa.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import tempfile
4
+
5
+
6
+ def predict_depth(model, image):
7
+ depth = model.infer_pil(image)
8
+ return depth
9
+
10
+
11
+ def create_demo():
12
+ with gr.Row():
13
+ with gr.Column(scale=1):
14
+ model_type = gr.Dropdown(["Prismer-Base", "Prismer-Large"], label="Model Size", value="Prismer-Base")
15
+ ques = gr.Textbox(label="Question", placeholder="What's the number of this player?")
16
+ rgb = gr.Image(label="Input Image", type='pil', elem_id='img-display-input').style(height="auto")
17
+ submit = gr.Button("Submit")
18
+ with gr.Column(scale=2):
19
+ pred = gr.Textbox(label="Model Prediction")
20
+ with gr.Row():
21
+ depth = gr.Image(label="Depth", elem_id='img-display-output')
22
+ edge = gr.Image(label="Edge", elem_id='img-display-output')
23
+ normals = gr.Image(label="Normals", elem_id='img-display-output')
24
+ with gr.Row():
25
+ seg = gr.Image(label="Segmentation", elem_id='img-display-output')
26
+ obj_det = gr.Image(label="Object Detection", elem_id='img-display-output')
27
+ ocr_det = gr.Image(label="OCR Detection", elem_id='img-display-output')
28
+
29
+ def on_submit(im, q, model_type):
30
+ return pred, depth, edge, normals, seg, obj_det, ocr_det
31
+
32
+ submit.click(on_submit, inputs=[rgb, ques, model_type], outputs=[pred, depth, edge, normals, seg, obj_det, ocr_det])
33
+ examples = gr.Examples(examples=["examples/1.jpeg"], inputs=[rgb])