Update Dockerfile and app.py for improved functionality
Browse files- .gitignore +2 -1
- Dockerfile +5 -1
- app.py +31 -82
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
.idea/
|
2 |
-
venv/
|
|
|
|
1 |
.idea/
|
2 |
+
venv/
|
3 |
+
weights/
|
Dockerfile
CHANGED
@@ -31,12 +31,16 @@ WORKDIR $HOME/app
|
|
31 |
RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
|
32 |
|
33 |
# Install dependencies
|
34 |
-
RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision
|
35 |
|
36 |
# Install SAM and Detectron2
|
37 |
RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
|
38 |
RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
|
39 |
|
|
|
|
|
|
|
|
|
40 |
COPY app.py .
|
41 |
|
42 |
RUN find $HOME/app
|
|
|
31 |
RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
|
32 |
|
33 |
# Install dependencies
|
34 |
+
RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision pillow
|
35 |
|
36 |
# Install SAM and Detectron2
|
37 |
RUN pip install 'git+https://github.com/facebookresearch/segment-anything.git'
|
38 |
RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
|
39 |
|
40 |
+
# Download weights
|
41 |
+
RUN mkdir -p $HOME/app/weights
|
42 |
+
RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
43 |
+
|
44 |
COPY app.py .
|
45 |
|
46 |
RUN find $HOME/app
|
app.py
CHANGED
@@ -1,98 +1,47 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
from detectron2.data import MetadataCatalog
|
4 |
-
from segment_anything import SamAutomaticMaskGenerator
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
print(metadata)
|
9 |
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
Sets: source="canvas", tool="sketch"
|
14 |
-
"""
|
15 |
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
|
21 |
-
def preprocess(self, x):
|
22 |
-
return super().preprocess(x)
|
23 |
|
|
|
|
|
24 |
|
25 |
-
demo = gr.Blocks()
|
26 |
-
image = ImageMask(
|
27 |
-
label="Input",
|
28 |
-
type="pil",
|
29 |
-
brush_radius=20.0,
|
30 |
-
brush_color="#FFFFFF")
|
31 |
-
slider = gr.Slider(
|
32 |
-
minimum=1,
|
33 |
-
maximum=3,
|
34 |
-
value=2,
|
35 |
-
label="Granularity",
|
36 |
-
info="Choose in [1, 1.5), [1.5, 2.5), [2.5, 3] for [seem, semantic-sam (multi-level), sam]")
|
37 |
-
mode = gr.Radio(
|
38 |
-
choices=['Automatic', 'Interactive', ],
|
39 |
-
value='Automatic',
|
40 |
-
label="Segmentation Mode")
|
41 |
-
image_out = gr.Image(label="Auto generation", type="pil")
|
42 |
-
slider_alpha = gr.Slider(
|
43 |
-
minimum=0,
|
44 |
-
maximum=1,
|
45 |
-
value=0.1,
|
46 |
-
label="Mask Alpha",
|
47 |
-
info="Choose in [0, 1]")
|
48 |
-
label_mode = gr.Radio(
|
49 |
-
choices=['Number', 'Alphabet'],
|
50 |
-
value='Number',
|
51 |
-
label="Mark Mode")
|
52 |
-
anno_mode = gr.CheckboxGroup(
|
53 |
-
choices=["Mask", "Box", "Mark"],
|
54 |
-
value=['Mask', 'Mark'],
|
55 |
-
label="Annotation Mode")
|
56 |
-
runBtn = gr.Button("Run")
|
57 |
|
58 |
-
|
59 |
-
|
|
|
60 |
|
61 |
-
with demo:
|
62 |
-
gr.Markdown(
|
63 |
-
gr.Markdown("<h3 style='text-align: center; margin-bottom: 1rem'>project: <a href='https://som-gpt4v.github.io/'>link</a>, arXiv: <a href='https://arxiv.org/abs/2310.11441'>link</a>, code: <a href='https://github.com/microsoft/SoM'>link</a></h3>")
|
64 |
-
gr.Markdown(f"<h3 style='margin-bottom: 1rem'>{description}</h3>")
|
65 |
with gr.Row():
|
66 |
with gr.Column():
|
67 |
-
|
68 |
-
slider.render()
|
69 |
-
with gr.Row():
|
70 |
-
mode.render()
|
71 |
-
anno_mode.render()
|
72 |
-
with gr.Row():
|
73 |
-
slider_alpha.render()
|
74 |
-
label_mode.render()
|
75 |
with gr.Column():
|
76 |
-
|
77 |
-
|
78 |
-
# with gr.Row():
|
79 |
-
# example = gr.Examples(
|
80 |
-
# examples=[
|
81 |
-
# ["examples/ironing_man.jpg"],
|
82 |
-
# ],
|
83 |
-
# inputs=image,
|
84 |
-
# cache_examples=False,
|
85 |
-
# )
|
86 |
-
# example = gr.Examples(
|
87 |
-
# examples=[
|
88 |
-
# ["examples/ironing_man_som.png"],
|
89 |
-
# ],
|
90 |
-
# inputs=image,
|
91 |
-
# cache_examples=False,
|
92 |
-
# label='Marked Examples',
|
93 |
-
# )
|
94 |
|
95 |
-
|
96 |
-
# outputs = image_out)
|
97 |
|
98 |
-
demo.queue().launch()
|
|
|
1 |
+
import torch
|
|
|
|
|
|
|
2 |
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
|
7 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
|
|
8 |
|
9 |
+
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
10 |
|
11 |
+
SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
|
12 |
+
SAM_MODEL_TYPE = "vit_h"
|
|
|
|
|
13 |
|
14 |
+
MARKDOWN = """
|
15 |
+
<h1 style='text-align: center'>
|
16 |
+
<img
|
17 |
+
src='https://som-gpt4v.github.io/website/img/som_logo.png'
|
18 |
+
style='height:50px; display:inline-block'
|
19 |
+
/>
|
20 |
+
Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
|
21 |
+
</h1>
|
22 |
+
"""
|
23 |
|
24 |
+
sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
|
25 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
26 |
|
|
|
|
|
27 |
|
28 |
+
def inference(image: np.ndarray) -> np.ndarray:
|
29 |
+
return image
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
image_input = gr.Image(label="Input", type="numpy")
|
33 |
+
image_output = gr.Image(label="SoM Visual Prompt", type="numpy", height=512)
|
34 |
+
run_button = gr.Button("Run")
|
35 |
|
36 |
+
with gr.Blocks() as demo:
|
37 |
+
gr.Markdown(MARKDOWN)
|
|
|
|
|
38 |
with gr.Row():
|
39 |
with gr.Column():
|
40 |
+
image_input.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
with gr.Column():
|
42 |
+
image_output.render()
|
43 |
+
run_button.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
+
run_button.click(inference, inputs=[image_input], outputs=image_output)
|
|
|
46 |
|
47 |
+
demo.queue().launch(debug=False, show_error=True)
|