Spaces:
Running
Running
zhang-ziang
committed on
Commit
·
74503df
1
Parent(s):
6965bae
infer aug
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
from PIL import Image
|
|
|
11 |
import rembg
|
12 |
from typing import Any
|
13 |
import torch.nn.functional as F
|
@@ -97,6 +98,37 @@ def remove_background(image: Image,
|
|
97 |
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
98 |
return image
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
def get_3angle(image):
|
101 |
|
102 |
# image = Image.open(image_path).convert('RGB')
|
@@ -108,7 +140,7 @@ def get_3angle(image):
|
|
108 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
109 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
110 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
111 |
-
confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0]
|
112 |
angles = torch.zeros(4)
|
113 |
angles[0] = gaus_ax_pred
|
114 |
angles[1] = gaus_pl_pred - 90
|
@@ -116,18 +148,86 @@ def get_3angle(image):
|
|
116 |
angles[3] = confidence
|
117 |
return angles
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
def get_3angle_infer_aug(image):
|
120 |
|
121 |
# image = Image.open(image_path).convert('RGB')
|
|
|
122 |
image_inputs = val_preprocess(images = image)
|
123 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
124 |
with torch.no_grad():
|
125 |
dino_pred = dino(image_inputs)
|
126 |
|
127 |
-
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
128 |
-
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
129 |
-
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
131 |
angles = torch.zeros(4)
|
132 |
angles[0] = gaus_ax_pred
|
133 |
angles[1] = gaus_pl_pred - 90
|
@@ -221,7 +321,7 @@ def infer_func(img, do_rm_bkg, do_infer_aug):
|
|
221 |
|
222 |
res_img = figure_to_img(fig)
|
223 |
# axis_model = "axis.obj"
|
224 |
-
return [res_img, float(angles[0]), float(angles[1]), float(angles[2]), float(angles[3])]
|
225 |
|
226 |
server = gr.Interface(
|
227 |
flagging_mode='never',
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
from PIL import Image
|
11 |
+
import random
|
12 |
import rembg
|
13 |
from typing import Any
|
14 |
import torch.nn.functional as F
|
|
|
98 |
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
99 |
return image
|
100 |
|
101 |
+
def random_crop(image, crop_scale=(0.8, 0.95)):
    """Return a randomly sized, randomly positioned crop of a PIL image.

    The crop width and height are drawn independently, each as a random
    integer between ``min_scale`` and ``max_scale`` times the corresponding
    side of ``image``. The top-left corner is then drawn uniformly among the
    positions that keep the crop inside the image.

    Args:
        image: Source ``PIL.Image.Image``.
        crop_scale: ``(min_scale, max_scale)`` with
            ``0 < min_scale <= max_scale <= 1``; fractions of each side.

    Returns:
        The result of ``image.crop`` for the selected box.

    Raises:
        TypeError: If ``image`` is not a ``PIL.Image.Image``.
        ValueError: If ``crop_scale`` is not a valid ``(min, max)`` pair.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(image, Image.Image):
        raise TypeError("input must be PIL.Image.Image")
    if len(crop_scale) != 2 or not (0 < crop_scale[0] <= crop_scale[1] <= 1):
        raise ValueError(
            "crop_scale must be (min_scale, max_scale) with "
            "0 < min_scale <= max_scale <= 1"
        )

    min_scale, max_scale = crop_scale
    width, height = image.size

    def pick_extent(side):
        # ``int(side * min_scale)`` truncates to 0 for very small sides, which
        # would yield an empty crop; keep at least one pixel when the image
        # has one, and never exceed the side itself.
        lowest = min(side, max(1, int(side * min_scale)))
        highest = max(lowest, int(side * max_scale))
        return random.randint(lowest, highest)

    # Draw order (width, height, left, top) matches the original so seeded
    # runs stay reproducible whenever the clamp above is not triggered.
    crop_width = pick_extent(width)
    crop_height = pick_extent(height)
    left = random.randint(0, width - crop_width)
    top = random.randint(0, height - crop_height)

    return image.crop((left, top, left + crop_width, top + crop_height))
|
124 |
+
|
125 |
+
def get_crop_images(img, num=3):
    """Return a list of ``num`` independent random crops of ``img``.

    Each element is produced by a separate ``random_crop(img)`` call with that
    function's default ``crop_scale``.

    Args:
        img: Image forwarded unchanged to ``random_crop``.
        num: Number of crops to generate; values <= 0 yield an empty list.

    Returns:
        A list with ``num`` cropped images.
    """
    return [random_crop(img) for _ in range(num)]
|
130 |
+
|
131 |
+
|
132 |
def get_3angle(image):
|
133 |
|
134 |
# image = Image.open(image_path).convert('RGB')
|
|
|
140 |
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
|
141 |
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
|
142 |
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
|
143 |
+
confidence = F.softmax(dino_pred[:, -2:], dim=-1)[0][0]
|
144 |
angles = torch.zeros(4)
|
145 |
angles[0] = gaus_ax_pred
|
146 |
angles[1] = gaus_pl_pred - 90
|
|
|
148 |
angles[3] = confidence
|
149 |
return angles
|
150 |
|
151 |
+
def remove_outliers_and_average(tensor, threshold=1.5):
    """Average a 1-D tensor after discarding IQR outliers.

    Values outside ``[q1 - threshold * iqr, q3 + threshold * iqr]`` are
    dropped, where ``q1``/``q3`` are the 25th/75th percentiles and
    ``iqr = q3 - q1``. If nothing survives the filter (e.g. the input
    contains NaN), the mean of the whole input is returned instead.

    Args:
        tensor: 1-D ``torch.Tensor``. Non-floating-point input is converted to
            ``float32`` because ``torch.quantile`` only accepts float/double.
        threshold: Multiplier applied to the IQR to build the fences.

    Returns:
        The mean of the retained values as a Python ``float``.

    Raises:
        ValueError: If ``tensor`` is not one-dimensional.
    """
    # Explicit raise instead of ``assert`` so validation survives ``python -O``.
    if tensor.dim() != 1:
        raise ValueError("dimension of input Tensor must equal to 1")

    # Integer tensors (e.g. raw ``argmax`` output) would make quantile() fail.
    if not torch.is_floating_point(tensor):
        tensor = tensor.to(torch.float32)

    q1 = torch.quantile(tensor, 0.25)
    q3 = torch.quantile(tensor, 0.75)
    spread = threshold * (q3 - q1)

    keep = (tensor >= q1 - spread) & (tensor <= q3 + spread)
    non_outliers = tensor[keep]

    if non_outliers.numel() == 0:
        return tensor.mean().item()
    return non_outliers.mean().item()
|
167 |
+
|
168 |
+
|
169 |
+
def _circular_mean_deg(angles_deg):
    """Return the circular mean of a tensor of angles in degrees, modulo 360."""
    radians = angles_deg * torch.pi / 180.0
    mean_x = torch.mean(torch.cos(radians))
    mean_y = torch.mean(torch.sin(radians))
    mean_angle = torch.atan2(mean_y, mean_x) * 180.0 / torch.pi
    # atan2 yields (-180, 180]; shift and wrap into the 0..360 range.
    return (mean_angle + 360) % 360


def remove_outliers_and_average_circular(tensor, threshold=1.5):
    """Average angles (degrees) on the circle after discarding outliers.

    Each angle is mapped to a point on the unit circle. A sample's "distance"
    is its Euclidean distance to the mean of those points; samples farther
    than ``q3 + threshold * iqr`` (quartiles of the distances) are treated as
    outliers. The circular mean of the remaining angles is returned. If no
    sample survives (e.g. NaN input), the circular mean of all samples is
    returned instead.

    Args:
        tensor: 1-D ``torch.Tensor`` of angles in degrees.
        threshold: Multiplier applied to the IQR of the distances.

    Returns:
        A 0-dim ``torch.Tensor`` holding the mean angle in degrees, wrapped
        with modulo 360.

    Raises:
        ValueError: If ``tensor`` is not one-dimensional.
    """
    # Explicit raise instead of ``assert`` so validation survives ``python -O``.
    if tensor.dim() != 1:
        raise ValueError("dimension of input Tensor must equal to 1")

    # Project the angles onto the unit circle.
    radians = tensor * torch.pi / 180.0
    x_coords = torch.cos(radians)
    y_coords = torch.sin(radians)

    # Distance of every sample from the mean vector.
    dx = x_coords - torch.mean(x_coords)
    dy = y_coords - torch.mean(y_coords)
    differences = torch.sqrt(dx * dx + dy * dy)

    # Quartiles and IQR of the distances.
    q1 = torch.quantile(differences, 0.25)
    q3 = torch.quantile(differences, 0.75)
    upper_bound = q3 + threshold * (q3 - q1)

    # Only the upper fence is meaningful here: the distance is non-negative
    # and a small distance means "close to the consensus direction". Applying
    # a lower fence as well would discard the most central samples.
    non_outliers = tensor[differences <= upper_bound]

    if non_outliers.numel() == 0:
        # Nothing survived the filter: fall back to the mean of all samples.
        return _circular_mean_deg(tensor)

    # Recompute the circular mean from the retained samples only.
    return _circular_mean_deg(non_outliers)
|
212 |
+
|
213 |
def get_3angle_infer_aug(image):
|
214 |
|
215 |
# image = Image.open(image_path).convert('RGB')
|
216 |
+
image = get_crop_images(image, num=6)
|
217 |
image_inputs = val_preprocess(images = image)
|
218 |
image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
|
219 |
with torch.no_grad():
|
220 |
dino_pred = dino(image_inputs)
|
221 |
|
222 |
+
gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1).to(torch.float32)
|
223 |
+
gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1).to(torch.float32)
|
224 |
+
gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1).to(torch.float32)
|
225 |
+
|
226 |
+
gaus_ax_pred = remove_outliers_and_average_circular(gaus_ax_pred)
|
227 |
+
gaus_pl_pred = remove_outliers_and_average(gaus_pl_pred)
|
228 |
+
gaus_ro_pred = remove_outliers_and_average(gaus_ro_pred)
|
229 |
+
|
230 |
+
confidence = torch.mean(F.softmax(dino_pred[:, -2:], dim=-1), dim=0)[0]
|
231 |
angles = torch.zeros(4)
|
232 |
angles[0] = gaus_ax_pred
|
233 |
angles[1] = gaus_pl_pred - 90
|
|
|
321 |
|
322 |
res_img = figure_to_img(fig)
|
323 |
# axis_model = "axis.obj"
|
324 |
+
return [res_img, round(float(angles[0]), 2), round(float(angles[1]), 2), round(float(angles[2]), 2), round(float(angles[3]), 2)]
|
325 |
|
326 |
server = gr.Interface(
|
327 |
flagging_mode='never',
|