thinh-huynh-re commited on
Commit
f092df4
1 Parent(s): 60b4128
Files changed (6) hide show
  1. .gitignore +8 -0
  2. app.py +25 -0
  3. inference.py +35 -0
  4. prepare_samples.py +31 -0
  5. requirements.txt +8 -0
  6. samples/.gitkeep +0 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ env
2
+ __pycache__
3
+ data
4
+ samples/*
5
+ !samples/.gitkeep
6
+ model_cache
7
+ image_processor_cache
8
+ flagged
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from inference import inference
4
+ from prepare_samples import prepare_samples
5
+
6
+ TITLE = "DPT Depth"
7
+ DESCRIPTION = "Gradio demo for DPT-Depth"
8
+ examples = prepare_samples()
9
+
10
+ demo = gr.Interface(
11
+ fn=inference,
12
+ inputs=[
13
+ gr.inputs.Image(label="RGB", type="pil"),
14
+ ],
15
+ outputs=[
16
+ gr.outputs.Image(label="Prediction", type="pil"),
17
+ ],
18
+ title=TITLE,
19
+ examples=examples,
20
+ description=DESCRIPTION,
21
+ )
22
+
23
+
24
+ if __name__ == "__main__":
25
+ demo.launch(server_name="0.0.0.0") # server_port=8541
inference.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, AutoModel
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+ from matplotlib import cm
6
+ from PIL import Image
7
+ from torch import Tensor
8
+
9
+ model = AutoModel.from_pretrained(
10
+ "RGBD-SOD/dptdepth", trust_remote_code=True, cache_dir="model_cache"
11
+ )
12
+ image_processor = AutoImageProcessor.from_pretrained(
13
+ "RGBD-SOD/dptdepth", trust_remote_code=True, cache_dir="image_processor_cache"
14
+ )
15
+
16
+
17
+ def inference(rgb: Image.Image) -> Image.Image:
18
+ rgb = rgb.convert(mode="RGB")
19
+
20
+ preprocessed_sample: Dict[str, Tensor] = image_processor.preprocess(
21
+ {
22
+ "rgb": rgb,
23
+ }
24
+ )
25
+
26
+ output: Dict[str, Tensor] = model(preprocessed_sample["rgb"])
27
+ postprocessed_sample: np.ndarray = image_processor.postprocess(
28
+ output["logits"], [rgb.size[1], rgb.size[0]]
29
+ )
30
+ prediction = Image.fromarray(np.uint8(cm.gist_earth(postprocessed_sample) * 255))
31
+ return prediction
32
+
33
+
34
+ if __name__ == "__main__":
35
+ pass
prepare_samples.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from typing import List, Tuple
4
+
5
+ from PIL import Image
6
+ from datasets import load_dataset
7
+
8
+
9
+ dataset = load_dataset("RGBD-SOD/test", "v1", split="train", cache_dir="data")
10
+ SAMPLES_DIR = "samples"
11
+
12
+
13
+ def prepare_samples():
14
+ samples: List[Tuple[str, str, str]] = []
15
+ for sample in dataset:
16
+ rgb: Image.Image = sample["rgb"]
17
+ depth: Image.Image = sample["depth"]
18
+ gt: Image.Image = sample["gt"]
19
+ name: str = sample["name"]
20
+ dir_path = os.path.join(SAMPLES_DIR, name)
21
+ shutil.rmtree(dir_path, ignore_errors=True)
22
+ os.makedirs(dir_path, exist_ok=True)
23
+ rgb_path = os.path.join(dir_path, f"rgb.jpg")
24
+ rgb.save(rgb_path)
25
+ depth_path = os.path.join(dir_path, f"depth.jpg")
26
+ depth.save(depth_path)
27
+ gt_path = os.path.join(dir_path, f"gt.png")
28
+ gt.save(gt_path)
29
+
30
+ samples.append([rgb_path, depth_path, gt_path])
31
+ return samples
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ opencv-python
4
+ transformers[torch]
5
+ torchvision
6
+ datasets
7
+ matplotlib
8
+ timm
samples/.gitkeep ADDED
File without changes