zhang-ziang committed on
Commit
6965bae
·
1 Parent(s): 0f72f6a

confidence added

Browse files
Files changed (1) hide show
  1. app.py +32 -6
app.py CHANGED
@@ -10,6 +10,7 @@ import io
10
  from PIL import Image
11
  import rembg
12
  from typing import Any
 
13
 
14
 
15
  from huggingface_hub import hf_hub_download
@@ -107,11 +108,31 @@ def get_3angle(image):
107
  gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
108
  gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
109
  gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
110
- angles = torch.zeros(3)
 
111
  angles[0] = gaus_ax_pred
112
  angles[1] = gaus_pl_pred - 90
113
  angles[2] = gaus_ro_pred - 30
 
 
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  return angles
116
 
117
  def scale(x):
@@ -145,10 +166,13 @@ def figure_to_img(fig):
145
  image = Image.open(buf).copy()
146
  return image
147
 
148
- def infer_func(img, do_rm_bkg):
149
  img = Image.fromarray(img)
150
  img = background_preprocess(img, do_rm_bkg)
151
- angles = get_3angle(img)
 
 
 
152
 
153
  fig, ax = plt.subplots(figsize=(8, 8))
154
 
@@ -197,21 +221,23 @@ def infer_func(img, do_rm_bkg):
197
 
198
  res_img = figure_to_img(fig)
199
  # axis_model = "axis.obj"
200
- return [res_img, float(angles[0]), float(angles[1]), float(angles[2])]
201
 
202
  server = gr.Interface(
203
  flagging_mode='never',
204
  fn=infer_func,
205
  inputs=[
206
  gr.Image(height=512, width=512, label="upload your image"),
207
- gr.Checkbox(label="Remove Background", value=True)
 
208
  ],
209
  outputs=[
210
  gr.Image(height=512, width=512, label="result image"),
211
  # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
212
  gr.Textbox(lines=1, label='Azimuth(0~360°)'),
213
  gr.Textbox(lines=1, label='Polar(-90~90°)'),
214
- gr.Textbox(lines=1, label='Rotation(-90~90°)')
 
215
  ]
216
  )
217
 
 
10
  from PIL import Image
11
  import rembg
12
  from typing import Any
13
+ import torch.nn.functional as F
14
 
15
 
16
  from huggingface_hub import hf_hub_download
 
108
  gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
109
  gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
110
  gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
111
+ confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
112
+ angles = torch.zeros(4)
113
  angles[0] = gaus_ax_pred
114
  angles[1] = gaus_pl_pred - 90
115
  angles[2] = gaus_ro_pred - 30
116
+ angles[3] = confidence
117
+ return angles
118
+
119
def get_3angle_infer_aug(image):
    """Predict three orientation angles and a confidence score for an image.

    The image is run through ``val_preprocess``, its pixel values are moved
    to ``device``, and ``dino`` is evaluated with gradient tracking disabled.
    Each angle is the arg-max over its own slice of the model output.

    NOTE(review): despite the name, no inference-time augmentation is applied
    here; the visible steps are a single forward pass, the same as the
    visible part of ``get_3angle``. Confirm whether augmentation is still to
    be added.

    Args:
        image: Input passed as ``val_preprocess(images=image)``.

    Returns:
        A 1-D float tensor with four entries:
        ``[azimuth, polar, rotation, confidence]``. Azimuth is a bin index in
        ``0..359``, polar is ``bin - 90`` over 180 bins, rotation is
        ``bin - 30`` over 60 bins, and confidence is a softmax probability
        in ``[0, 1]``.
    """
    image_inputs = val_preprocess(images=image)
    image_inputs['pixel_values'] = torch.from_numpy(
        np.array(image_inputs['pixel_values'])
    ).to(device)
    with torch.no_grad():
        dino_pred = dino(image_inputs)

    # Output layout along the last axis: 360 azimuth bins, 180 polar bins,
    # 60 rotation bins, and the final two entries used for confidence.
    gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
    gaus_pl_pred = torch.argmax(dino_pred[:, 360:360 + 180], dim=-1)
    gaus_ro_pred = torch.argmax(dino_pred[:, 360 + 180:360 + 180 + 60], dim=-1)

    # The softmax over the last two logits has shape (batch, 2). Indexing only
    # the batch axis leaves a 2-element tensor, which cannot be stored in the
    # single slot ``angles[3]``; pick one probability explicitly.
    # NOTE(review): column 0 is assumed to be the "confident" class — confirm
    # against the training code.
    confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0, 0]

    angles = torch.zeros(4)
    angles[0] = gaus_ax_pred
    angles[1] = gaus_pl_pred - 90   # shift bin index to a signed polar angle
    angles[2] = gaus_ro_pred - 30   # shift bin index to a signed rotation angle
    angles[3] = confidence
    return angles
137
 
138
  def scale(x):
 
166
  image = Image.open(buf).copy()
167
  return image
168
 
169
+ def infer_func(img, do_rm_bkg, do_infer_aug):
170
  img = Image.fromarray(img)
171
  img = background_preprocess(img, do_rm_bkg)
172
+ if do_infer_aug:
173
+ angles = get_3angle_infer_aug(img)
174
+ else:
175
+ angles = get_3angle(img)
176
 
177
  fig, ax = plt.subplots(figsize=(8, 8))
178
 
 
221
 
222
  res_img = figure_to_img(fig)
223
  # axis_model = "axis.obj"
224
+ return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
225
 
226
# Gradio front end for ``infer_func``.
# The three inputs map, in order, to ``infer_func(img, do_rm_bkg, do_infer_aug)``;
# the five outputs map, in order, to the five values it returns.
server = gr.Interface(
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference time augmentation", value=False),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        # 3D model output is currently disabled:
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # NOTE(review): the angle functions derive rotation from 60 bins
        # shifted by -30, which does not match the range advertised in this
        # label — confirm which one is intended.
        gr.Textbox(lines=1, label='Rotation(-90~90°)'),
        gr.Textbox(lines=1, label='Confidence(0~1)'),
    ],
    flagging_mode='never',
)
243