Kaori1707 committed on
Commit
76daa54
1 Parent(s): 72db49b

add segment model

Files changed (1): app.py (+49 -12)
app.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import torch
 from torchvision.transforms import Compose
 import cv2
-from dpt.models import DPTDepthModel
+from dpt.models import DPTDepthModel, DPTSegmentationModel
 from dpt.transforms import Resize, NormalizeImage, PrepareForNet
 import os
 
@@ -11,16 +11,31 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("device: %s" % device)
 default_models = {
     "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
+    "segment_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt"
 }
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
-net_w = net_h = 384
-model = DPTDepthModel(
+
+depth_model = DPTDepthModel(
     path=default_models["dpt_hybrid"],
     backbone="vitb_rn50_384",
     non_negative=True,
     enable_attention_hooks=False,
 )
+
+depth_model.eval()
+depth_model.to(device)
+
+seg_model = DPTSegmentationModel(
+    150,
+    path=default_models["segment_hybrid"],
+    backbone="vitb_rn50_384",
+)
+seg_model.eval()
+seg_model.to(device)
+
+# Transform
+net_w = net_h = 384
 normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
 transform = Compose(
     [
@@ -38,8 +53,6 @@ transform = Compose(
     ]
 )
 
-model.eval()
-model.to(device)
 
 def write_depth(depth, bits=1, absolute_depth=False):
     """Write depth map to pfm and png file.
@@ -67,7 +80,8 @@ def write_depth(depth, bits=1, absolute_depth=False):
         return out.astype("uint8")
     elif bits == 2:
        return out.astype("uint16")
-
+
+
 
 def DPT(image):
     img_input = transform({"image": image})["image"]
@@ -75,7 +89,7 @@ def DPT(image):
     with torch.no_grad():
         sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
 
-        prediction = model.forward(sample)
+        prediction = depth_model.forward(sample)
         prediction = (
             torch.nn.functional.interpolate(
                 prediction.unsqueeze(1),
@@ -90,6 +104,26 @@ def DPT(image):
 
     depth_img = write_depth(prediction, bits=2)
     return depth_img
+
+def Segment(image):
+    img_input = transform({"image": image})["image"]
+
+    # compute
+    with torch.no_grad():
+        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
+        # if optimize == True and device == torch.device("cuda"):
+        #     sample = sample.to(memory_format=torch.channels_last)
+        #     sample = sample.half()
+
+        out = seg_model.forward(sample)
+
+        prediction = torch.nn.functional.interpolate(
+            out, size=image.shape[:2], mode="bicubic", align_corners=False
+        )
+        prediction = torch.argmax(prediction, dim=1) + 1
+        prediction = prediction.squeeze().cpu().numpy()
+
+        return prediction
 
 title = " AISeed AI Application Demo "
 description = "# A Demo of Deep Learning for Depth Estimation"
@@ -99,16 +133,19 @@ with gr.Blocks() as demo:
     demo.title = title
     gr.Markdown(description)
     with gr.Row():
-        im = gr.Image(label="Input Image")
-        im_2 = gr.Image(label="Depth Image")
         with gr.Column():
-
+
+            im_2 = gr.Image(label="Depth Image")
+            im_3 = gr.Image(label="Segment Image")
+        with gr.Column():
+            im = gr.Image(label="Input Image")
             btn1 = gr.Button(value="Depth Estimator")
             btn1.click(DPT, inputs=[im], outputs=[im_2])
+            btn2 = gr.Button(value="Segment")
+            btn2.click(Segment, inputs=[im], outputs=[im_3])
     gr.Examples(examples=example_list,
                 inputs=[im],
-                outputs=[im_2],
-                fn=DPT)
+                outputs=[im_2])
 
 if __name__ == "__main__":
     demo.launch()
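
A note on the new code path: DPTSegmentationModel is constructed with 150 output classes, matching the ADE20K label set the dpt_hybrid-ade20k weights were trained on, and torch.argmax(prediction, dim=1) + 1 shifts to ADE20K's 1-based labels. Segment() therefore returns a raw HxW array of class indices, which gr.Image will render as a near-black grayscale image. A minimal sketch of one way to colorize that map before display, assuming a fixed random palette is acceptable (colorize_segmentation is a hypothetical helper, not part of this commit):

import numpy as np

def colorize_segmentation(class_map, num_classes=150, seed=0):
    """Map an HxW array of 1-based class indices to an HxWx3 uint8 image."""
    rng = np.random.default_rng(seed)
    # One fixed color per class; row 0 stays unused because Segment()
    # returns argmax + 1, i.e. labels 1..num_classes.
    palette = rng.integers(0, 256, size=(num_classes + 1, 3), dtype=np.uint8)
    return palette[class_map]

Wiring it in would be a one-line change to the handler, e.g. btn2.click(lambda img: colorize_segmentation(Segment(img)), inputs=[im], outputs=[im_3]).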