File size: 5,066 Bytes
51f6859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import OrderedDict

import torch
from mmcv import Config
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

from mmdet.apis import init_detector, inference_detector
from mmdet.datasets import (CocoDataset)
from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
                         update_data_root)

import gradio as gr

# Maps the model name shown in the UI dropdown to the config file that
# `inference` loads with `Config.fromfile`.  Insertion order is the dropdown
# order, and the first entry is the dropdown's default value.
# Every key must point at the config file of the same name.
config_dict = OrderedDict([
    ('swin-l-hdetr_sam-vit-b',
     'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'),
    ('swin-l-hdetr_sam-vit-l',
     'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py'),
    ('swin-l-hdetr_sam-vit-h',
     'projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py'),
    ('focalnet-l-dino_sam-vit-b',
     'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
    ('focalnet-l-dino_sam-vit-l',
     'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
    ('focalnet-l-dino_sam-vit-h',
     'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py'),
])


def _dirname_as_module_path(path):
    """Return the directory part of ``path`` as a dotted module path.

    ``'a/b/c.py'`` becomes ``'a.b'``; only ``'/'`` is treated as a separator.
    """
    return os.path.dirname(path).replace('/', '.')


def inference(img, config):
    """Run the selected detector on one image and return the rendered result.

    Args:
        img: Input image handed to ``inference_detector``. ``None`` (nothing
            uploaded) short-circuits before any config or model is loaded.
        config: Model name; must be a key of ``config_dict``.

    Returns:
        ``None`` when ``img`` is ``None``; otherwise the value returned by
        ``model.show_result(..., show=False, out_file=None)`` (presumably the
        annotated image -- confirm against the mmdet version in use).

    Raises:
        KeyError: If ``config`` is not a key of ``config_dict``.
    """
    if img is None:
        return None

    config_path = config_dict[config]
    cfg = Config.fromfile(config_path)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)

    # Import the plugin package so that its modules get registered before the
    # model is built.  The package is derived from cfg.plugin_dir when given,
    # otherwise from the directory that holds the config file.
    if getattr(cfg, 'plugin', False):
        import importlib

        if hasattr(cfg, 'plugin_dir'):
            module_path = _dirname_as_module_path(cfg.plugin_dir)
            print(module_path)
        else:
            module_path = _dirname_as_module_path(config_path)
        importlib.import_module(module_path)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # NOTE(review): an MLU-only host is also given "cuda" here -- confirm that
    # this is the intended device string for that backend.
    device = "cuda" if (IS_CUDA_AVAILABLE or IS_MLU_AVAILABLE) else "cpu"

    # No checkpoint path is passed; the model is built from the config alone
    # (presumably the config arranges for its own weights -- verify).
    # A fresh model is built on every call and released when the call returns.
    model = init_detector(cfg, None, device=device)
    model.CLASSES = CocoDataset.CLASSES

    results = inference_detector(model, img)
    return model.show_result(
        img,
        results,
        bbox_color=CocoDataset.PALETTE,
        text_color=CocoDataset.PALETTE,
        mask_color=CocoDataset.PALETTE,
        show=False,
        out_file=None,
        score_thr=0.3,
    )


# Markdown shown at the top of the demo page; `main` passes it to `gr.Markdown`.
description = """
#  <center>Prompt Segment Anything (zero-shot instance segmentation demo)</center>
Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
"""


def main():
    """Assemble the Gradio Blocks interface for the demo and launch it.

    The page shows ``description``, an input image with a model dropdown and
    Clear/Submit buttons, an output image, and a set of clickable examples.
    Submit runs ``inference`` on the current image and model; Clear empties
    both image boxes.
    """
    model_names = list(config_dict)
    example_rows = [
        ["./assets/img1.jpg", "swin-l-hdetr_sam-vit-b"],
        ["./assets/img2.jpg", "swin-l-hdetr_sam-vit-l"],
        ["./assets/img3.jpg", "swin-l-hdetr_sam-vit-l"],
        ["./assets/img4.jpg", "focalnet-l-dino_sam-vit-b"],
    ]

    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Column():
            with gr.Row():
                # Left side: image upload, model picker and action buttons.
                with gr.Column():
                    image_in = gr.Image(type="numpy", label="Input Image")
                    model_choice = gr.Dropdown(
                        choices=model_names,
                        value=model_names[0],
                        label='Model',
                        multiselect=False,
                    )
                    with gr.Row():
                        reset_button = gr.Button(value="Clear")
                        run_button = gr.Button(value="Submit")
                # Right side: rendered result.
                image_out = gr.Image(type="numpy", label="Output")
            gr.Examples(
                examples=example_rows,
                inputs=[image_in, model_choice],
                outputs=image_out,
                fn=inference,
            )

        run_button.click(
            fn=inference,
            inputs=[image_in, model_choice],
            outputs=image_out,
        )
        # Clearing only resets the two image boxes and bypasses the queue.
        reset_button.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[image_in, image_out],
            queue=False,
        )

    demo.queue()
    demo.launch(share=True)


# Launch the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()