File size: 6,812 Bytes
51f6859
 
 
 
 
fa4d18a
a394c57
 
 
 
073441b
 
 
 
fa4d18a
51f6859
 
 
 
 
 
 
 
 
 
a6190ef
 
 
 
 
51f6859
7bbb0c6
51f6859
7bbb0c6
 
 
51f6859
 
 
 
 
e8dd899
51f6859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8c32ac
 
 
51f6859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6190ef
 
 
 
51f6859
 
 
 
 
 
 
 
 
 
 
42fb25d
51f6859
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) OpenMMLab. All rights reserved.
import os
from collections import OrderedDict

import torch

# print(torch.__version__)
# torch_ver, cuda_ver = torch.__version__.split('+')
# os.system('pip list')
# os.system(f'pip install pycocotools==2.0.0 mmdet mmcv-full==1.5.0 -f https://download.openmmlab.com/mmcv/dist/{cuda_ver}/torch1.10.0/index.html --no-cache-dir')
def _fetch_checkpoint(url, dst):
    """Download the H-DETR checkpoint at *url* and convert it into *dst*.

    The raw download is stored next to *dst* with a ``_raw`` suffix, then
    converted with ``tools/convert_ckpt.py`` into *dst*. Nothing is done when
    *dst* already exists, so restarting the app neither re-downloads (the
    ``wget`` module may save a renamed copy beside an existing file) nor
    re-converts an already converted checkpoint. As with the ``os.system``
    calls this replaces, failures are reported rather than raised so the app
    can still start.

    Args:
        url: release URL of the original checkpoint.
        dst: path the converted checkpoint is written to.
    """
    import subprocess
    import sys

    if os.path.exists(dst):
        return
    os.makedirs(os.path.dirname(dst) or '.', exist_ok=True)
    root, ext = os.path.splitext(dst)
    raw = f'{root}_raw{ext}'

    if not os.path.exists(raw):
        # Use the running interpreter rather than whatever `python` is on PATH,
        # and pass an argv list so no shell is involved.
        fetched = subprocess.run([sys.executable, '-m', 'wget', url, '-o', raw])
        if fetched.returncode != 0 or not os.path.exists(raw):
            print(f'WARNING: failed to download {url} to {raw}')
            return

    converted = subprocess.run(
        [sys.executable, 'tools/convert_ckpt.py', raw, dst])
    if converted.returncode != 0 or not os.path.exists(dst):
        print(f'WARNING: failed to convert {raw} to {dst}')
        # Do not leave a partial file behind that the next start would skip.
        if os.path.exists(dst):
            os.remove(dst)
        return
    # Only the converted checkpoint is needed from here on.
    os.remove(raw)


for _ckpt_url, _ckpt_dst in (
    ('https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth',
     'ckpt/r50_hdetr.pth'),
    ('https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth',
     'ckpt/swin_t_hdetr.pth'),
):
    _fetch_checkpoint(_ckpt_url, _ckpt_dst)

from mmcv import Config
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE

from mmdet.apis import init_detector, inference_detector
from mmdet.datasets import (CocoDataset)
from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
                         update_data_root)

import gradio as gr

# Model name shown in the "Model" dropdown -> mmdet config file for that model.
# Order matters: the UI lists the models in this order and preselects the
# first one. (A dict literal keeps insertion order on Python 3.7+.)
config_dict = OrderedDict({
    'r50-hdetr_sam-vit-b': 'projects/configs/hdetr/r50-hdetr_sam-vit-b.py',
    'r50-hdetr_sam-vit-l': 'projects/configs/hdetr/r50-hdetr_sam-vit-l.py',
    'swin-t-hdetr_sam-vit-b': 'projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py',
    'swin-t-hdetr_sam-vit-l': 'projects/configs/hdetr/swin-t-hdetr_sam-vit-l.py',
    'swin-l-hdetr_sam-vit-b': 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py',
    'swin-l-hdetr_sam-vit-l': 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py',
    # NOTE(review): this disabled entry points at the vit-l config, not vit-h.
    # 'swin-l-hdetr_sam-vit-h': 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py',
    'focalnet-l-dino_sam-vit-b': 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py',
    # 'focalnet-l-dino_sam-vit-l': 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py',
    # 'focalnet-l-dino_sam-vit-h': 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py',
})


def _load_config(config_path):
    """Load an mmdet config file and apply the standard mmdet preprocessing."""
    cfg = Config.fromfile(config_path)
    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)
    # update data root according to MMDET_DATASETS
    update_data_root(cfg)
    cfg = compat_cfg(cfg)
    # set multi-process settings
    setup_multi_processes(cfg)
    return cfg


def _module_path_from_dir(dir_path):
    """Turn a '/'-separated directory path into a dotted module path."""
    return dir_path.replace('/', '.')


def _import_plugin_modules(cfg, config_path):
    """Import the config's plugin package so its modules update the registry.

    Only runs when ``cfg.plugin`` is truthy. The package imported is the
    parent directory of ``cfg.plugin_dir`` when that key exists. Otherwise it
    is the directory containing the config file. The module is imported only
    for its side effects.
    """
    if not (hasattr(cfg, 'plugin') and cfg.plugin):
        return
    import importlib

    if hasattr(cfg, 'plugin_dir'):
        plugin_root = os.path.dirname(cfg.plugin_dir)
    else:
        # import dir is the dirpath for the config file
        plugin_root = os.path.dirname(config_path)
    importlib.import_module(_module_path_from_dir(plugin_root))


def _select_device():
    """Return the device string to build the detector on: cuda, mlu, or cpu."""
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    if IS_MLU_AVAILABLE:
        # Previously this fell through to 'cuda', which cannot work on an
        # MLU-only machine.
        return 'mlu'
    return 'cpu'


def inference(img, config):
    """Run the selected detector + SAM model on *img* and draw the results.

    Args:
        img: image array from the Gradio ``Image(type="numpy")`` input, or
            ``None`` when the input box is empty.
        config: model name; must be a key of ``config_dict``.

    Returns:
        The image produced by ``model.show_result`` with detections scoring
        above 0.3 drawn using the COCO palette, or ``None`` when *img* is
        ``None``.

    Raises:
        KeyError: if *config* is not a key of ``config_dict``.
    """
    if img is None:
        return None
    print(f"config: {config}")
    config_path = config_dict[config]
    cfg = _load_config(config_path)
    _import_plugin_modules(cfg, config_path)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # No checkpoint path is passed here; weights presumably come from the
    # config itself — TODO confirm.
    # The model is rebuilt per request and deleted afterwards, so no model
    # stays referenced between requests.
    model = init_detector(cfg, None, device=_select_device())
    model.CLASSES = CocoDataset.CLASSES

    # NOTE(review): Gradio numpy images are RGB; confirm the model's test
    # pipeline expects that channel order rather than BGR.
    results = inference_detector(model, img)
    visualize = model.show_result(
        img,
        results,
        bbox_color=CocoDataset.PALETTE,
        text_color=CocoDataset.PALETTE,
        mask_color=CocoDataset.PALETTE,
        show=False,
        out_file=None,
        score_thr=0.3
    )
    del model
    return visualize


# Read once at import time. The badge below links to
# huggingface.co/spaces/<SPACE_ID>, so SPACE_ID is presumably set by the
# Hugging Face Spaces runtime and absent elsewhere.
SPACE_ID = os.getenv('SPACE_ID')

# Extra markdown with a "Duplicate Space" badge, only when running on a Space.
_duplicate_badge = (
    f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
    if SPACE_ID is not None
    else ''
)

# Markdown header rendered at the top of the demo page.
description = """
#  <center>Prompt Segment Anything (zero-shot instance segmentation demo)</center>
Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
""" + _duplicate_badge
    
def main():
    """Build the Gradio Blocks UI and launch it with request queuing enabled.

    Layout: a left column with the input image, the "Model" dropdown and the
    Clear/Submit buttons, next to the output image, followed by clickable
    examples. Submit and the examples both call ``inference``; Clear resets
    the input and output images.
    """
    model_names = list(config_dict.keys())
    example_rows = [
        [image_path, "r50-hdetr_sam-vit-b"]
        for image_path in ("./assets/img1.jpg", "./assets/img2.jpg",
                           "./assets/img3.jpg", "./assets/img4.jpg")
    ]

    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(type="numpy", label="Input Image")
                    model_type = gr.Dropdown(
                        choices=model_names,
                        value=model_names[0],  # preselect the first model
                        label='Model',
                        multiselect=False,
                    )
                    with gr.Row():
                        clear_btn = gr.Button(value="Clear")
                        submit_btn = gr.Button(value="Submit")
                output_img = gr.Image(type="numpy", label="Output")
            gr.Examples(
                examples=example_rows,
                inputs=[input_img, model_type],
                outputs=output_img,
                fn=inference,
            )

        submit_btn.click(
            inference,
            inputs=[input_img, model_type],
            outputs=output_img,
        )
        # Clearing is instant, so it bypasses the request queue.
        clear_btn.click(
            lambda: [None, None],
            None,
            [input_img, output_img],
            queue=False,
        )

    demo.queue()
    demo.launch()


if __name__ == '__main__':
    main()