Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os | |
from collections import OrderedDict | |
import torch | |
from mmcv import Config | |
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE | |
from mmdet.apis import init_detector, inference_detector | |
from mmdet.datasets import (CocoDataset) | |
from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes, | |
update_data_root) | |
import gradio as gr | |
# Maps each model name offered in the "Model" dropdown to the mmdet config
# file that builds it. main() uses insertion order as the dropdown order, and
# the first key is the default selection. Every config file is named after
# its key, so a mismatch between key and filename indicates a bug.
config_dict = OrderedDict([
    ('swin-l-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'),
    ('swin-l-hdetr_sam-vit-l', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py'),
    ('swin-l-hdetr_sam-vit-h', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py'),
    ('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
    ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
    ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py'),
])
def _flip_channels(image):
    """Return a C-contiguous copy of ``image`` with RGB <-> BGR swapped.

    Arrays that are not 3-channel ``(H, W, 3)`` images are returned unchanged.
    """
    if image.ndim == 3 and image.shape[2] == 3:
        # .copy() drops the negative stride produced by the ::-1 slice.
        return image[..., ::-1].copy()
    return image


def _select_device():
    """Return the device string to pass to ``init_detector``."""
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    if IS_MLU_AVAILABLE:
        # Cambricon MLU devices are addressed as 'mlu', not 'cuda'.
        return 'mlu'
    return 'cpu'


def _import_plugin(cfg, config_path):
    """Import the plugin package named by ``cfg`` so its modules register.

    Does nothing unless ``cfg.plugin`` exists and is truthy. The package
    directory comes from ``cfg.plugin_dir`` when the config sets it, otherwise
    from the directory that holds ``config_path``. That directory is then
    turned into a dotted module path, e.g. ``projects/configs/hdetr`` ->
    ``projects.configs.hdetr``.
    """
    if not (hasattr(cfg, 'plugin') and cfg.plugin):
        return
    import importlib
    if hasattr(cfg, 'plugin_dir'):
        module_dir = os.path.dirname(cfg.plugin_dir)
    else:
        module_dir = os.path.dirname(config_path)
    # Importing is done only for its registration side effects.
    importlib.import_module(module_dir.replace('/', '.'))


def inference(img, config):
    """Segment ``img`` with the selected model and return the drawn result.

    Args:
        img: Image from the Gradio ``Image(type="numpy")`` input (RGB order),
            or ``None`` when nothing has been uploaded.
        config: Model name; must be a key of ``config_dict``.

    Returns:
        An RGB copy of ``img`` with the model's detections drawn on it
        (``score_thr=0.3``), or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None
    config_path = config_dict[config]
    cfg = Config.fromfile(config_path)
    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)
    # update data root according to MMDET_DATASETS
    update_data_root(cfg)
    cfg = compat_cfg(cfg)
    # set multi-process settings
    setup_multi_processes(cfg)
    # import modules from the plugin package so the registry is updated
    _import_plugin(cfg, config_path)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # The detector is built for this request only and deleted before returning.
    model = init_detector(cfg, None, device=_select_device())
    model.CLASSES = CocoDataset.CLASSES

    # Gradio supplies RGB arrays, while mmdet treats ndarray inputs as BGR
    # (the mmcv.imread convention). Convert before inference and convert the
    # drawing back to RGB before handing it to Gradio.
    bgr_img = _flip_channels(img)
    results = inference_detector(model, bgr_img)
    visualize = model.show_result(
        bgr_img,
        results,
        bbox_color=CocoDataset.PALETTE,
        text_color=CocoDataset.PALETTE,
        mask_color=CocoDataset.PALETTE,
        show=False,
        out_file=None,
        score_thr=0.3
    )
    del model
    return _flip_channels(visualize)
# Markdown rendered at the top of the Gradio page by main(). The string
# begins and ends with a newline.
description = "\n".join([
    "",
    "# <center>Prompt Segment Anything (zero-shot instance segmentation demo)</center>",
    "Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)",
    "You can select the model you want to use from the \"Model\" dropdown menu and click \"Submit\" to segment the image you uploaded to the \"Input Image\" box.",
    "",
])
def main():
    """Build the Gradio Blocks UI for the demo, enable queuing and launch it."""
    model_names = list(config_dict)
    example_rows = [
        ["./assets/img1.jpg", "swin-l-hdetr_sam-vit-b"],
        ["./assets/img2.jpg", "swin-l-hdetr_sam-vit-l"],
        ["./assets/img3.jpg", "swin-l-hdetr_sam-vit-l"],
        ["./assets/img4.jpg", "focalnet-l-dino_sam-vit-b"],
    ]

    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Column():
            with gr.Row():
                # First column of the row: upload, model picker and buttons.
                with gr.Column():
                    input_img = gr.Image(type="numpy", label="Input Image")
                    model_type = gr.Dropdown(
                        choices=model_names,
                        value=model_names[0],
                        label='Model',
                        multiselect=False,
                    )
                    with gr.Row():
                        clear_btn = gr.Button(value="Clear")
                        submit_btn = gr.Button(value="Submit")
                # Next to it in the same row: the rendered result.
                output_img = gr.Image(type="numpy", label="Output")
            gr.Examples(
                examples=example_rows,
                inputs=[input_img, model_type],
                outputs=output_img,
                fn=inference,
            )
        submit_btn.click(
            fn=inference,
            inputs=[input_img, model_type],
            outputs=output_img,
        )
        # Clearing empties both images; it bypasses the request queue.
        clear_btn.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[input_img, output_img],
            queue=False,
        )
    demo.queue()
    demo.launch(share=True)
if __name__ == "__main__":
    # Launch the demo only when this file is executed directly.
    main()