import os
from typing import Tuple

import gradio as gr
from datasets import load_dataset
from PIL import Image

from prepare_samples import prepare_samples

# Directory that contains this script, derived from the module's own path.
DIR_PATH = os.path.dirname(__file__)


def inference(rgb: Image.Image, depth: Image.Image) -> Tuple[Image.Image]:
    """Return the prediction for an RGB-D pair as a one-element tuple.

    This is currently a pass-through stub: the RGB input is handed back
    unchanged and ``depth`` is not used.

    Args:
        rgb: The RGB input image.
        depth: The depth input image (ignored for now).

    Returns:
        A 1-tuple whose only element is ``rgb``.
    """
    return (rgb,)


# Load the "v1" configuration ("train" split) of the RGBD-SOD/test dataset,
# using the relative "data" directory as the cache location.
# NOTE(review): `dataset` is not referenced again in the visible code;
# presumably the call is kept so the cache is populated for
# `prepare_samples` -- confirm before removing.
dataset = load_dataset(
    "RGBD-SOD/test",
    "v1",
    split="train",
    cache_dir="data",
)

# Heading and blurb passed to the Gradio interface below.
TITLE = "BBS-Net Demo"
DESCRIPTION = (
    "Gradio demo for BBS-Net: RGB-D salient object detection "
    "with a bifurcated backbone strategy network."
)
# Example inputs passed to the interface's `examples` argument; built by the
# project-local `prepare_samples` helper.
examples = prepare_samples()

# Gradio interface: two PIL image inputs (RGB and depth) are fed to
# `inference`, and its single result is shown as a PIL image.
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces are tied to older
# Gradio releases -- confirm the pinned version before upgrading.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(label="RGB", type="pil"),
        gr.inputs.Image(label="Depth", type="pil"),
    ],
    outputs=[
        gr.outputs.Image(label="Prediction", type="pil"),
    ],
    title=TITLE,
    examples=examples,
    description=DESCRIPTION,
)


if __name__ == "__main__":
    # Start the web app only when this file is executed as a script.
    # NOTE(review): `enable_queue` and `cache_examples` as `launch()` keyword
    # arguments depend on the installed Gradio version -- confirm they are
    # still accepted before upgrading.
    demo.launch(enable_queue=True, cache_examples=False)