zhang-ziang commited on
Commit
74503df
·
1 Parent(s): 6965bae
Files changed (1) hide show
  1. app.py +106 -6
app.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
 
11
  import rembg
12
  from typing import Any
13
  import torch.nn.functional as F
@@ -97,6 +98,37 @@ def remove_background(image: Image,
97
  image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
98
  return image
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def get_3angle(image):
101
 
102
  # image = Image.open(image_path).convert('RGB')
@@ -108,7 +140,7 @@ def get_3angle(image):
108
  gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
109
  gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
110
  gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
111
- confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
112
  angles = torch.zeros(4)
113
  angles[0] = gaus_ax_pred
114
  angles[1] = gaus_pl_pred - 90
@@ -116,18 +148,86 @@ def get_3angle(image):
116
  angles[3] = confidence
117
  return angles
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def get_3angle_infer_aug(image):
120
 
121
  # image = Image.open(image_path).convert('RGB')
 
122
  image_inputs = val_preprocess(images = image)
123
  image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
124
  with torch.no_grad():
125
  dino_pred = dino(image_inputs)
126
 
127
- gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
128
- gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
129
- gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
130
- confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
 
 
 
 
 
131
  angles = torch.zeros(4)
132
  angles[0] = gaus_ax_pred
133
  angles[1] = gaus_pl_pred - 90
@@ -221,7 +321,7 @@ def infer_func(img, do_rm_bkg, do_infer_aug):
221
 
222
  res_img = figure_to_img(fig)
223
  # axis_model = "axis.obj"
224
- return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
225
 
226
  server = gr.Interface(
227
  flagging_mode='never',
 
8
  import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
11
+ import random
12
  import rembg
13
  from typing import Any
14
  import torch.nn.functional as F
 
98
  image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
99
  return image
100
 
101
+ def random_crop(image, crop_scale=(0.8, 0.95)):
102
+ """
103
+ 随机裁切图片
104
+ image (numpy.ndarray): (H, W, C)。
105
+ crop_scale (tuple): (min_scale, max_scale)。
106
+ """
107
+ assert isinstance(image, Image.Image), "iput must be PIL.Image.Image"
108
+ assert len(crop_scale) == 2 and 0 < crop_scale[0] <= crop_scale[1] <= 1
109
+
110
+ width, height = image.size
111
+
112
+ # 计算裁切的高度和宽度
113
+ crop_width = random.randint(int(width * crop_scale[0]), int(width * crop_scale[1]))
114
+ crop_height = random.randint(int(height * crop_scale[0]), int(height * crop_scale[1]))
115
+
116
+ # 随机选择裁切的起始点
117
+ left = random.randint(0, width - crop_width)
118
+ top = random.randint(0, height - crop_height)
119
+
120
+ # 裁切图片
121
+ cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
122
+
123
+ return cropped_image
124
+
125
+ def get_crop_images(img, num=3):
126
+ cropped_images = []
127
+ for i in range(num):
128
+ cropped_images.append(random_crop(img))
129
+ return cropped_images
130
+
131
+
132
  def get_3angle(image):
133
 
134
  # image = Image.open(image_path).convert('RGB')
 
140
  gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
141
  gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
142
  gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
143
+ confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
144
  angles = torch.zeros(4)
145
  angles[0] = gaus_ax_pred
146
  angles[1] = gaus_pl_pred - 90
 
148
  angles[3] = confidence
149
  return angles
150
 
151
+ def remove_outliers_and_average(tensor, threshold=1.5):
152
+ assert tensor.dim() == 1, "dimension of input Tensor must equal to 1"
153
+
154
+ q1 = torch.quantile(tensor, 0.25)
155
+ q3 = torch.quantile(tensor, 0.75)
156
+ iqr = q3 - q1
157
+
158
+ lower_bound = q1 - threshold * iqr
159
+ upper_bound = q3 + threshold * iqr
160
+
161
+ non_outliers = tensor[(tensor >= lower_bound) & (tensor <= upper_bound)]
162
+
163
+ if len(non_outliers) == 0:
164
+ return tensor.mean().item()
165
+
166
+ return non_outliers.mean().item()
167
+
168
+
169
+ def remove_outliers_and_average_circular(tensor, threshold=1.5):
170
+ assert tensor.dim() == 1, "dimension of input Tensor must equal to 1"
171
+
172
+ # 将角度转换为二维平面上的点
173
+ radians = tensor * torch.pi / 180.0
174
+ x_coords = torch.cos(radians)
175
+ y_coords = torch.sin(radians)
176
+
177
+ # 计算平均向量
178
+ mean_x = torch.mean(x_coords)
179
+ mean_y = torch.mean(y_coords)
180
+
181
+ differences = torch.sqrt((x_coords - mean_x) * (x_coords - mean_x) + (y_coords - mean_y) * (y_coords - mean_y))
182
+
183
+ # 计算四分位数和 IQR
184
+ q1 = torch.quantile(differences, 0.25)
185
+ q3 = torch.quantile(differences, 0.75)
186
+ iqr = q3 - q1
187
+
188
+ # 计算上下限
189
+ lower_bound = q1 - threshold * iqr
190
+ upper_bound = q3 + threshold * iqr
191
+
192
+ # 筛选非离群点
193
+ non_outliers = tensor[(differences >= lower_bound) & (differences <= upper_bound)]
194
+
195
+ if len(non_outliers) == 0:
196
+ mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
197
+ mean_angle = (mean_angle + 360) % 360
198
+ return mean_angle # 如果没有非离群点,返回 None
199
+
200
+ # 对非离群点再次计算平均向量
201
+ radians = non_outliers * torch.pi / 180.0
202
+ x_coords = torch.cos(radians)
203
+ y_coords = torch.sin(radians)
204
+
205
+ mean_x = torch.mean(x_coords)
206
+ mean_y = torch.mean(y_coords)
207
+
208
+ mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
209
+ mean_angle = (mean_angle + 360) % 360
210
+
211
+ return mean_angle
212
+
213
  def get_3angle_infer_aug(image):
214
 
215
  # image = Image.open(image_path).convert('RGB')
216
+ image = get_crop_images(image, num=6)
217
  image_inputs = val_preprocess(images = image)
218
  image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
219
  with torch.no_grad():
220
  dino_pred = dino(image_inputs)
221
 
222
+ gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
223
+ gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
224
+ gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1).to(torch.float32)
225
+
226
+ gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
227
+ gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
228
+ gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
229
+
230
+ confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
231
  angles = torch.zeros(4)
232
  angles[0] = gaus_ax_pred
233
  angles[1] = gaus_pl_pred - 90
 
321
 
322
  res_img = figure_to_img(fig)
323
  # axis_model = "axis.obj"
324
+ return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
325
 
326
  server = gr.Interface(
327
  flagging_mode='never',