"""Gradio demo app for BBS-Net (RGB-D salient object detection).

``inference`` is currently a placeholder: it does not run a model and
returns the RGB input unchanged as the "prediction".
"""
from typing import Tuple  # noqa: F401  # no longer used here; kept so existing imports of it keep working
import os

import gradio as gr
from PIL import Image
from datasets import load_dataset

from prepare_samples import prepare_samples

DIR_PATH = os.path.dirname(__file__)


def inference(rgb: Image.Image, depth: Image.Image) -> Image.Image:
    """Produce the prediction image for an RGB / depth input pair.

    Placeholder implementation: no model is invoked and ``depth`` is ignored;
    the RGB image is returned unchanged.

    Args:
        rgb: RGB input image (PIL, as configured on the "RGB" input).
        depth: Depth input image (PIL, as configured on the "Depth" input).

    Returns:
        A single PIL image for the one "Prediction" output component.
        This must be the bare image, not a 1-tuple: when an Interface has a
        single output, Gradio hands the return value directly to that
        component, so a tuple would reach Image postprocessing and fail.
    """
    return rgb


# NOTE(review): `dataset` is not referenced anywhere in this module. It may be
# loaded here to populate the local "data" cache that prepare_samples() reads;
# confirm before removing, since it downloads data at import time.
dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")

TITLE = "BBS-Net Demo"
DESCRIPTION = "Gradio demo for BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network."

# Example rows for the Interface; their format is defined by prepare_samples().
examples = prepare_samples()

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(label="RGB", type="pil"),
        gr.inputs.Image(label="Depth", type="pil"),
    ],
    outputs=[
        gr.outputs.Image(label="Prediction", type="pil"),
    ],
    title=TITLE,
    examples=examples,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch(enable_queue=True, cache_examples=False)