Spaces:
Running
Running
zhang-ziang
commited on
Commit
·
6965bae
1
Parent(s):
0f72f6a
confidence added
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import io
|
|
10 |
from PIL import Image
|
11 |
import rembg
|
12 |
from typing import Any
|
|
|
13 |
|
14 |
|
15 |
from huggingface_hub import hf_hub_download
|
@@ -107,11 +108,31 @@ def get_3angle(image):
|
|
107 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
108 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
109 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
110 |
-
|
|
|
111 |
angles[0] = gaus_ax_pred
|
112 |
angles[1] = gaus_pl_pred - 90
|
113 |
angles[2] = gaus_ro_pred - 30
|
|
|
|
|
|
|
|
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
return angles
|
116 |
|
117 |
def scale(x):
|
@@ -145,10 +166,13 @@ def figure_to_img(fig):
|
|
145 |
image = Image.open(buf).copy()
|
146 |
return image
|
147 |
|
148 |
-
def infer_func(img, do_rm_bkg):
|
149 |
img = Image.fromarray(img)
|
150 |
img = background_preprocess(img, do_rm_bkg)
|
151 |
-
|
|
|
|
|
|
|
152 |
|
153 |
fig, ax = plt.subplots(figsize=(8, 8))
|
154 |
|
@@ -197,21 +221,23 @@ def infer_func(img, do_rm_bkg):
|
|
197 |
|
198 |
res_img = figure_to_img(fig)
|
199 |
# axis_model = "axis.obj"
|
200 |
-
return [res_img, float(angles[0]), float(angles[1]), float(angles[2])]
|
201 |
|
202 |
server = gr.Interface(
|
203 |
flagging_mode='never',
|
204 |
fn=infer_func,
|
205 |
inputs=[
|
206 |
gr.Image(height=512, width=512, label="upload your image"),
|
207 |
-
gr.Checkbox(label="Remove Background", value=True)
|
|
|
208 |
],
|
209 |
outputs=[
|
210 |
gr.Image(height=512, width=512, label="result image"),
|
211 |
# gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
|
212 |
gr.Textbox(lines=1, label='Azimuth(0~360°)'),
|
213 |
gr.Textbox(lines=1, label='Polar(-90~90°)'),
|
214 |
-
gr.Textbox(lines=1, label='Rotation(-90~90°)')
|
|
|
215 |
]
|
216 |
)
|
217 |
|
|
|
10 |
from PIL import Image
|
11 |
import rembg
|
12 |
from typing import Any
|
13 |
+
import torch.nn.functional as F
|
14 |
|
15 |
|
16 |
from huggingface_hub import hf_hub_download
|
|
|
108 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
109 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
110 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
111 |
+
confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
|
112 |
+
angles = torch.zeros(4)
|
113 |
angles[0] = gaus_ax_pred
|
114 |
angles[1] = gaus_pl_pred - 90
|
115 |
angles[2] = gaus_ro_pred - 30
|
116 |
+
angles[3] = confidence
|
117 |
+
return angles
|
118 |
+
|
119 |
+
def get_3angle_infer_aug(image):
|
120 |
|
121 |
+
# image = Image.open(image_path).convert('RGB')
|
122 |
+
image_inputs = val_preprocess(images = image)
|
123 |
+
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
124 |
+
with torch.no_grad():
|
125 |
+
dino_pred = dino(image_inputs)
|
126 |
+
|
127 |
+
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
128 |
+
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
129 |
+
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
130 |
+
confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
|
131 |
+
angles = torch.zeros(4)
|
132 |
+
angles[0] = gaus_ax_pred
|
133 |
+
angles[1] = gaus_pl_pred - 90
|
134 |
+
angles[2] = gaus_ro_pred - 30
|
135 |
+
angles[3] = confidence
|
136 |
return angles
|
137 |
|
138 |
def scale(x):
|
|
|
166 |
image = Image.open(buf).copy()
|
167 |
return image
|
168 |
|
169 |
+
def infer_func(img, do_rm_bkg, do_infer_aug):
|
170 |
img = Image.fromarray(img)
|
171 |
img = background_preprocess(img, do_rm_bkg)
|
172 |
+
if do_infer_aug:
|
173 |
+
angles = get_3angle_infer_aug(img)
|
174 |
+
else:
|
175 |
+
angles = get_3angle(img)
|
176 |
|
177 |
fig, ax = plt.subplots(figsize=(8, 8))
|
178 |
|
|
|
221 |
|
222 |
res_img = figure_to_img(fig)
|
223 |
# axis_model = "axis.obj"
|
224 |
+
return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
|
225 |
|
226 |
server = gr.Interface(
|
227 |
flagging_mode='never',
|
228 |
fn=infer_func,
|
229 |
inputs=[
|
230 |
gr.Image(height=512, width=512, label="upload your image"),
|
231 |
+
gr.Checkbox(label="Remove Background", value=True),
|
232 |
+
gr.Checkbox(label="Inference time augmentation", value=False)
|
233 |
],
|
234 |
outputs=[
|
235 |
gr.Image(height=512, width=512, label="result image"),
|
236 |
# gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
|
237 |
gr.Textbox(lines=1, label='Azimuth(0~360°)'),
|
238 |
gr.Textbox(lines=1, label='Polar(-90~90°)'),
|
239 |
+
gr.Textbox(lines=1, label='Rotation(-90~90°)'),
|
240 |
+
gr.Textbox(lines=1, label='Confidence(0~1)')
|
241 |
]
|
242 |
)
|
243 |
|