Spaces:
Running
Running
import gradio as gr | |
from paths import * | |
import numpy as np | |
from vision_tower import DINOv2_MLP | |
from transformers import AutoImageProcessor | |
import torch | |
import os | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
import rembg | |
from typing import Any | |
from huggingface_hub import hf_hub_download | |
# Download the OriNet DINOv2-large checkpoint (or reuse the copy cached under ./).
ckpt_path = hf_hub_download(
    repo_id="Viglong/OriNet",
    filename="celarge/dino_weight.pt",
    repo_type="model",
    cache_dir='./',
    resume_download=True,
)
print(ckpt_path)

save_path = './'
device = 'cpu'


def _build_orinet(checkpoint, target_device):
    """Construct the DINOv2 + MLP model, put it in eval mode, and load ``checkpoint`` into it."""
    model = DINOv2_MLP(
        dino_mode='large',
        in_dim=1024,
        # 360 + 180 + 60 angle-bin logits (sliced apart in get_3angle) plus 2 further outputs.
        out_dim=360 + 180 + 60 + 2,
        evaluate=True,
        mask_dino=False,
        frozen_back=False,
    ).to(target_device)
    model.eval()
    print('model create')
    # NOTE(review): torch.load unpickles the downloaded file, which can run code
    # embedded in it; this trusts the Viglong/OriNet repo. Consider
    # weights_only=True if the checkpoint is a plain state dict.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    print('weight loaded')
    return model


dino = _build_orinet(ckpt_path, device)

val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
# Lazily created rembg session shared by all requests (see _get_rembg_session).
_rembg_session = None


def _get_rembg_session():
    """Return a module-wide rembg session, creating it on first use.

    ``rembg.new_session()`` creates a new session object on every call, so
    reusing one avoids repeating that setup for each request.
    """
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg.new_session()
    return _rembg_session


def background_preprocess(input_image, do_remove_background):
    """Optionally strip the background and re-center the foreground.

    Args:
        input_image: PIL image to process.
        do_remove_background: When true, remove the background with rembg and
            then crop/pad the foreground so it fills 85% of a square canvas.
            When false, ``input_image`` is returned unchanged. ``resize_foreground``
            requires an RGBA image, which only the removal path guarantees.

    Returns:
        The processed (or untouched) image.
    """
    if not do_remove_background:
        return input_image
    input_image = remove_background(input_image, _get_rembg_session())
    return resize_foreground(input_image, 0.85)
def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Crop an RGBA image to its visible foreground, square it, and add a margin.

    The foreground is every pixel whose alpha is > 0. Its tight bounding box is
    cut out and centered on a square canvas of transparent black. The canvas is
    then padded equally on all sides so the foreground spans ``ratio`` of the
    final side length.

    Args:
        image: RGBA image (anything ``np.array`` turns into an H x W x 4 array).
        ratio: Fraction of the output side occupied by the foreground; expected
            in (0, 1]. Larger values would require negative padding, which
            ``np.pad`` rejects.

    Returns:
        A new square RGBA ``PIL.Image.Image``.

    Raises:
        ValueError: if the image is not 4-channel or has no pixel with alpha > 0.
    """
    arr = np.array(image)
    if arr.ndim != 3 or arr.shape[-1] != 4:
        raise ValueError(
            f"resize_foreground expects an RGBA image, got array of shape {arr.shape}"
        )

    rows, cols = np.where(arr[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground found no foreground pixels (alpha is 0 everywhere)")

    # max() is an inclusive index, so +1 keeps the last foreground row/column.
    fg = arr[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    # Pad the crop to a square, splitting any odd remainder toward bottom/right.
    h, w = fg.shape[:2]
    size = max(h, w)
    top, left = (size - h) // 2, (size - w) // 2
    square = np.pad(
        fg,
        ((top, size - h - top), (left, size - w - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # Grow the canvas so the foreground takes up `ratio` of the side.
    new_size = int(size / ratio)
    margin = new_size - size
    before = margin // 2
    after = margin - before
    padded = np.pad(
        square,
        ((before, after), (before, after), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return Image.fromarray(padded)
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Cut the background out of ``image`` with rembg, unless it is already transparent.

    An RGBA image whose minimum alpha is below 255 is treated as already having
    its background removed and is returned as-is, unless ``force`` is true.

    Args:
        image: PIL image to process.
        rembg_session: Optional session passed through to ``rembg.remove``.
        force: Run rembg even when the image already has transparency.
        **rembg_kwargs: Extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The rembg output, or the original image when removal is skipped.
    """
    has_transparency = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if has_transparency and not force:
        return image
    return rembg.remove(image, session=rembg_session, **rembg_kwargs)
def get_3angle(image):
    """Predict [azimuth, polar, rotation] for ``image`` with the module-level model.

    The model output is read as three consecutive groups of logits with 360,
    180 and 60 entries. Each angle is the argmax index within its group minus an
    offset of 0, 90 and 30 respectively. Any outputs after the first 600 are not
    used here.

    NOTE(review): this presumably assumes one bin per degree; confirm against
    the training code.

    Returns:
        A float tensor of shape (3,).
    """
    inputs = val_preprocess(images=image)
    inputs['pixel_values'] = torch.from_numpy(np.array(inputs['pixel_values'])).to(device)
    with torch.no_grad():
        logits = dino(inputs)

    angles = torch.zeros(3)
    # (number of bins, offset subtracted from the argmax), in output order.
    layout = ((360, 0), (180, 90), (60, 30))
    start = 0
    for slot, (n_bins, offset) in enumerate(layout):
        angles[slot] = torch.argmax(logits[:, start:start + n_bins], dim=-1) - offset
        start += n_bins
    return angles
def scale(x, factor=3):
    """Return ``x`` multiplied by ``factor`` (default 3).

    get_proj2D_XYZ uses this to lengthen the projected axis vectors so the
    arrows are easy to see on the plot.

    Args:
        x: Number or array to scale.
        factor: Multiplier; the default of 3 preserves the original behavior.
    """
    return x * factor
def get_proj2D_XYZ(phi, theta, gamma):
    """Project the rotated X, Y and Z axes onto the 2-D image plane.

    Args:
        phi, theta, gamma: Angles in radians. infer_func passes the azimuth,
            the polar angle and the negated rotation.

    Returns:
        A tuple ``(x, y, z)`` of 2-element arrays, each passed through ``scale``.
        The first element of each is the horizontal component and the second
        the vertical one, as draw_axis consumes them.
    """
    sin_p, cos_p = np.sin(phi), np.cos(phi)
    sin_t, cos_t = np.sin(theta), np.cos(theta)
    sin_g, cos_g = np.sin(gamma), np.cos(gamma)

    projected = (
        np.array([
            -1 * sin_p * cos_g - cos_p * sin_t * sin_g,
            sin_p * sin_g - cos_p * sin_t * cos_g,
        ]),
        np.array([
            -1 * cos_p * cos_g + sin_p * sin_t * sin_g,
            cos_p * sin_g + sin_p * sin_t * cos_g,
        ]),
        np.array([cos_t * sin_g, cos_t * cos_g]),
    )
    return tuple(scale(vec) for vec in projected)
# Draw one projected 3-D coordinate axis as a 2-D arrow.
def draw_axis(ax, origin, vector, color, label=None):
    """Draw ``vector`` as an arrow from ``origin`` on ``ax``, with an optional label.

    ``angles='xy', scale_units='xy', scale=1`` make quiver interpret the vector
    in data units, so the arrow tip lands at ``origin + vector``. The label, if
    given, is placed slightly past the tip at ``origin + 1.1 * vector``.

    Args:
        ax: Matplotlib Axes to draw on.
        origin: Start point (x, y).
        vector: Arrow components (dx, dy).
        color: Matplotlib color for both the arrow and the label.
        label: Optional text drawn next to the arrow tip.
    """
    ax.quiver(origin[0], origin[1], vector[0], vector[1],
              angles='xy', scale_units='xy', scale=1, color=color)
    if label is not None:
        ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1,
                label, color=color, fontsize=12)
def figure_to_img(fig):
    """Render a Matplotlib figure to an in-memory JPEG and return it as a PIL image.

    The figure is saved with ``bbox_inches='tight'`` to trim surrounding
    whitespace. The opened image is copied so the result no longer depends on
    the buffer, which is closed before returning.
    """
    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format='JPG', bbox_inches='tight')
        buffer.seek(0)
        return Image.open(buffer).copy()
    finally:
        buffer.close()
def _arrow_draw_order(azimuth_deg):
    """Return indices into [front, right, top] giving the order to draw the arrows.

    Arrows drawn later are painted over earlier ones. ``azimuth_deg`` is the
    predicted azimuth in degrees.
    """
    if 45 < azimuth_deg <= 225:
        return [0, 1, 2]
    if 225 < azimuth_deg < 315:
        return [2, 0, 1]
    return [2, 1, 0]


def infer_func(img, do_rm_bkg):
    """Gradio handler: predict the object's orientation and draw it as axis arrows.

    Args:
        img: Uploaded image as a NumPy array (gr.Image's default output type),
            or ``None`` when nothing was uploaded.
        do_rm_bkg: Whether to remove the background (and re-center the object)
            before inference.

    Returns:
        ``[result_image, azimuth, polar, rotation]``: the rendered PIL image
        followed by the three predicted angles (degrees) as floats.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    img = Image.fromarray(img)
    img = background_preprocess(img, do_rm_bkg)
    angles = get_3angle(img)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Fit the image inside the [-5, 5] square while keeping its aspect ratio.
        w, h = img.size
        if h > w:
            extent = [-5 * w / h, 5 * w / h, -5, 5]
        else:
            extent = [-5, 5, -5 * h / w, 5 * h / w]
        ax.imshow(img, extent=extent, zorder=0, aspect='auto')

        origin = np.array([0, 0])
        # Rotation angles in radians; gamma uses the negated rotation angle.
        phi = np.radians(angles[0])
        theta = np.radians(angles[1])
        gamma = np.radians(-1 * angles[2])
        # Projected rotated axis vectors.
        rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)

        arrow_attr = [
            {'point': rot_x, 'color': 'r', 'label': 'front'},
            {'point': rot_y, 'color': 'g', 'label': 'right'},
            {'point': rot_z, 'color': 'b', 'label': 'top'},
        ]
        # The draw-order thresholds are in degrees. Comparing the radian value
        # `phi` (at most ~6.28) against them always fell through to the last case.
        for idx in _arrow_draw_order(float(angles[0])):
            attr = arrow_attr[idx]
            draw_axis(ax, origin, attr['point'], attr['color'], attr['label'])

        # Hide axes and grid, and show the same [-5, 5] window the image was fitted to.
        ax.set_axis_off()
        ax.grid(False)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        res_img = figure_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leak one figure.
        plt.close(fig)

    return [res_img, float(angles[0]), float(angles[1]), float(angles[2])]
# Gradio UI. Inputs: an image and a background-removal toggle.
# Outputs: the rendered axes image and the three predicted angles.
_ui_inputs = [
    gr.Image(height=512, width=512, label="upload your image"),
    gr.Checkbox(label="Remove Background", value=True),
]
_ui_outputs = [
    gr.Image(height=512, width=512, label="result image"),
    gr.Textbox(lines=1, label='Azimuth(0~360°)'),
    gr.Textbox(lines=1, label='Polar(-90~90°)'),
    # NOTE(review): get_3angle yields rotation values in [-30, 29] (60 bins minus 30);
    # confirm whether this label's range is intended.
    gr.Textbox(lines=1, label='Rotation(-90~90°)'),
]
server = gr.Interface(
    fn=infer_func,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    flagging_mode='never',
)
server.launch()