thinh-researcher commited on
Commit
e5c7710
1 Parent(s): 58803c8
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app.py +41 -0
  3. prepare_samples.py +29 -0
  4. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ env
2
+ __pycache__
3
+ data
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import gradio as gr
3
+ import os
4
+ from PIL import Image
5
+ from datasets import load_dataset
6
+ from prepare_samples import prepare_samples
7
+
8
+ DIR_PATH = os.path.dirname(__file__)
9
+
10
+
11
+ def inference(rgb: Image.Image, depth: Image.Image) -> Tuple[Image.Image]:
12
+ return (rgb,)
13
+
14
+
15
+ dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
16
+
17
+ # with gr.Blocks() as demo:
18
+ # with gr.Row(elem_id="center"):
19
+ # gr.Markdown("# BBS-Net Demo")
20
+
21
+ TITLE = "BBS-Net Demo"
22
+ DESCRIPTION = "Gradio demo for BBS-Net: RGB-D salient object detection with a bifurcated backbone strategy network."
23
+ examples = prepare_samples()
24
+
25
+ demo = gr.Interface(
26
+ fn=inference,
27
+ inputs=[
28
+ gr.inputs.Image(label="RGB", type="pil"),
29
+ gr.inputs.Image(label="Depth", type="pil"),
30
+ ],
31
+ outputs=[
32
+ gr.outputs.Image(label="Prediction", type="pil"),
33
+ ],
34
+ title=TITLE,
35
+ examples=examples,
36
+ description=DESCRIPTION,
37
+ )
38
+
39
+
40
+ if __name__ == "__main__":
41
+ demo.launch(enable_queue=True, cache_examples=False)
prepare_samples.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from datasets import load_dataset
3
+ from PIL import Image
4
+ import os
5
+ import shutil
6
+
7
+ dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
8
+ SAMPLES_DIR = "samples"
9
+
10
+
11
+ def prepare_samples():
12
+ samples: List[Tuple[str, str, str]] = []
13
+ for sample in dataset:
14
+ rgb: Image.Image = sample["rgb"]
15
+ depth: Image.Image = sample["depth"]
16
+ gt: Image.Image = sample["gt"]
17
+ name: str = sample["name"]
18
+ dir_path = os.path.join(SAMPLES_DIR, name)
19
+ shutil.rmtree(dir_path, ignore_errors=True)
20
+ os.makedirs(dir_path, exist_ok=True)
21
+ rgb_path = os.path.join(dir_path, f"rgb.jpg")
22
+ rgb.save(rgb_path)
23
+ depth_path = os.path.join(dir_path, f"depth.jpg")
24
+ depth.save(depth_path)
25
+ gt_path = os.path.join(dir_path, f"gt.png")
26
+ gt.save(gt_path)
27
+
28
+ samples.append([rgb_path, depth_path, gt_path])
29
+ return samples
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ datasets