# Orient-Anything / app.py
# Hosting-page metadata: zhang-ziang · axis label · 0f72f6a · 6.91 kB
# (page links: raw, history, blame)
import gradio as gr
from paths import *
import numpy as np
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
import os
import matplotlib.pyplot as plt
import io
from PIL import Image
import rembg
from typing import Any
from huggingface_hub import hf_hub_download
# Fetch the pretrained checkpoint file from the Hugging Face Hub model repo
# "Viglong/OriNet", using the current directory as the download cache.
ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
# NOTE(review): save_path is not referenced anywhere else in the visible source.
save_path = './'
# The model and its input tensors are placed on the CPU.
device = 'cpu'
# Build the DINOv2_MLP model. get_3angle reads the first 360 + 180 + 60 outputs
# as three separate argmax ranges; the remaining 2 outputs are not used in the
# visible part of this file.
dino = DINOv2_MLP(
dino_mode = 'large',
in_dim = 1024,
out_dim = 360+180+60+2,
evaluate = True,
mask_dino = False,
frozen_back = False
).to(device)
# Switch the module to evaluation mode before weights are loaded and used.
dino.eval()
print('model create')
# NOTE(review): torch.load unpickles the file; the checkpoint comes from the
# Hub download above -- confirm that source is trusted.
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')
# Image preprocessor used by get_3angle. DINO_LARGE is presumably exported by
# `from paths import *` -- TODO confirm.
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
def background_preprocess(input_image, do_remove_background):
    """Optionally strip the background and re-frame the foreground.

    When ``do_remove_background`` is falsy the image is returned untouched.
    Otherwise a new rembg session is created, the image is passed through
    ``remove_background`` and the result through ``resize_foreground`` with
    a ratio of 0.85.
    """
    if not do_remove_background:
        return input_image

    session = rembg.new_session()
    cutout = remove_background(input_image, session)
    return resize_foreground(cutout, 0.85)
def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Crop an RGBA image to its foreground and re-center it on a square canvas.

    The foreground is the bounding box of all pixels whose alpha is greater
    than zero. It is padded with fully transparent pixels to a square, then
    padded again so that the square occupies ``ratio`` of the final side
    length.

    Args:
        image: Image whose array form has 4 channels in the last axis.
        ratio: Fraction of the output side length taken by the foreground
            square (for example 0.85).

    Returns:
        A new square PIL image built from the padded array.

    Raises:
        ValueError: If the image does not have 4 channels or contains no
            pixel with alpha greater than zero.
    """
    pixels = np.array(image)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError(
            f"expected an RGBA image with 4 channels, got array shape {pixels.shape}"
        )

    rows, cols = np.where(pixels[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no pixel with alpha > 0; nothing to crop")

    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()

    # Crop the foreground. The max indices are inclusive, so the slice end
    # must be max + 1; otherwise the last row and column are cut off.
    fg = pixels[y1:y2 + 1, x1:x2 + 1]

    # Pad the shorter side symmetrically so the crop becomes square.
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    squared = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # Grow the canvas so the square fills `ratio` of the final side length.
    new_size = int(squared.shape[0] / ratio)
    before = (new_size - size) // 2
    after = new_size - size - before
    framed = np.pad(
        squared,
        ((before, after), (before, after), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return Image.fromarray(framed)
def remove_background(image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Run rembg on ``image`` unless it already carries transparency.

    An image in mode "RGBA" whose minimum alpha value is below 255 is
    returned unchanged, unless ``force`` is truthy. In every other case the
    image is passed to ``rembg.remove`` with the given session and any extra
    keyword arguments.
    """
    # True when the alpha band already contains at least one non-opaque value.
    already_cut_out = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_cut_out:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
def get_3angle(image):
    """Predict three orientation angles for ``image`` with the global model.

    The image is preprocessed with ``val_preprocess`` and fed to ``dino``.
    The output is split into three consecutive ranges of 360, 180 and 60
    values; the argmax of each range is taken, then shifted by 0, -90 and
    -30 respectively.

    Returns:
        A float tensor of length 3 holding the three shifted argmax values.
    """
    inputs = val_preprocess(images=image)
    pixel_array = np.array(inputs['pixel_values'])
    inputs['pixel_values'] = torch.from_numpy(pixel_array).to(device)

    with torch.no_grad():
        prediction = dino(inputs)

    # Column boundaries of the three argmax ranges in the model output.
    first_end = 360
    second_end = first_end + 180
    third_end = second_end + 60

    first_bin = torch.argmax(prediction[:, 0:first_end], dim=-1)
    second_bin = torch.argmax(prediction[:, first_end:second_end], dim=-1)
    third_bin = torch.argmax(prediction[:, second_end:third_end], dim=-1)

    # NOTE(review): writing each argmax result into a scalar slot presumably
    # relies on a batch size of one -- confirm against callers.
    angles = torch.zeros(3)
    angles[0] = first_bin
    angles[1] = second_bin - 90
    angles[2] = third_bin - 30
    return angles
def scale(x):
    """Return ``x`` multiplied by a fixed factor of 3."""
    factor = 3
    return x * factor
def get_proj2D_XYZ(phi, theta, gamma):
    """Compute three scaled 2-D vectors from the angles phi, theta and gamma.

    The angles are fed directly to ``np.sin`` / ``np.cos``, so they are
    expected in radians. Each returned vector is a 2-element array that has
    been passed through ``scale``.

    Returns:
        A tuple ``(x, y, z)`` of the three 2-D vectors.
    """
    sin_p, cos_p = np.sin(phi), np.cos(phi)
    sin_t, cos_t = np.sin(theta), np.cos(theta)
    sin_g, cos_g = np.sin(gamma), np.cos(gamma)

    x_vec = np.array([
        -1 * sin_p * cos_g - cos_p * sin_t * sin_g,
        sin_p * sin_g - cos_p * sin_t * cos_g,
    ])
    y_vec = np.array([
        -1 * cos_p * cos_g + sin_p * sin_t * sin_g,
        cos_p * sin_g + sin_p * sin_t * cos_g,
    ])
    z_vec = np.array([
        cos_t * sin_g,
        cos_t * cos_g,
    ])

    return scale(x_vec), scale(y_vec), scale(z_vec)
# Draw a 3D coordinate axis
def draw_axis(ax, origin, vector, color, label=None):
    """Draw one arrow on ``ax`` and optionally annotate its tip.

    Args:
        ax: Object providing ``quiver`` and ``text`` methods.
        origin: Indexable pair giving the arrow's start point.
        vector: Indexable pair giving the arrow's x and y components.
        color: Color used for both the arrow and the label.
        label: Text placed at 1.1 times the vector from the origin. Nothing
            is written when it is ``None``.
    """
    ax.quiver(
        origin[0], origin[1], vector[0], vector[1],
        angles='xy', scale_units='xy', scale=1, color=color,
    )
    # Identity check: `!= None` would dispatch to a custom `__ne__` and is
    # not a reliable test for "no label given".
    if label is not None:
        ax.text(
            origin[0] + vector[0] * 1.1,
            origin[1] + vector[1] * 1.1,
            label,
            color=color,
            fontsize=12,
        )
def figure_to_img(fig):
    """Save ``fig`` into an in-memory buffer and return it as a PIL image.

    The figure is written with ``fig.savefig`` using the 'JPG' format and a
    tight bounding box, then read back with ``Image.open``.
    """
    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format='JPG', bbox_inches='tight')
        buffer.seek(0)
        # copy() is taken while the buffer is still open, so the returned
        # image does not depend on the buffer afterwards.
        rendered = Image.open(buffer).copy()
    finally:
        buffer.close()
    return rendered
def infer_func(img, do_rm_bkg):
    """Predict the orientation of an uploaded image and draw it as three arrows.

    Args:
        img: Array accepted by ``Image.fromarray`` (the Gradio image input).
        do_rm_bkg: Whether to run background removal before prediction.

    Returns:
        A list ``[result_image, angle0, angle1, angle2]`` where the image
        shows the (possibly preprocessed) input with the "front", "right"
        and "top" arrows drawn over it, and the three angles are the values
        from ``get_3angle`` as Python floats.
    """
    img = Image.fromarray(img)
    img = background_preprocess(img, do_rm_bkg)
    angles = get_3angle(img)
    # Plain floats in degrees; reused for plotting and for the return value.
    azimuth, polar, rotation = (float(a) for a in angles)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        w, h = img.size
        if h > w:
            extent = [-5 * w / h, 5 * w / h, -5, 5]
        else:
            extent = [-5, 5, -5 * h / w, 5 * h / w]
        # extent sets the display range of the image.
        ax.imshow(img, extent=extent, zorder=0, aspect='auto')

        origin = np.array([0, 0])

        # Rotation angles in radians.
        phi = np.radians(azimuth)
        theta = np.radians(polar)
        gamma = np.radians(-1 * rotation)

        # Vectors after rotation.
        rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)

        arrow_attr = [
            {'point': rot_x, 'color': 'r', 'label': 'front'},
            {'point': rot_y, 'color': 'g', 'label': 'right'},
            {'point': rot_z, 'color': 'b', 'label': 'top'},
        ]

        # The thresholds are degree values, so compare the azimuth in
        # degrees. Comparing the radian value `phi` (at most about 6.28)
        # could never reach the first two branches.
        if 45 < azimuth <= 225:
            order = [0, 1, 2]
        elif 225 < azimuth < 315:
            order = [2, 0, 1]
        else:
            order = [2, 1, 0]

        # Arrows are drawn in this order, so later ones are painted on top.
        for idx in order:
            arrow = arrow_attr[idx]
            draw_axis(ax, origin, arrow['point'], arrow['color'], arrow['label'])

        # Hide the axes and the grid.
        ax.set_axis_off()
        ax.grid(False)

        # Fix the coordinate range.
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)

        res_img = figure_to_img(fig)
    finally:
        # pyplot keeps every figure it creates until it is closed; without
        # this each request would leave one figure behind.
        plt.close(fig)

    return [res_img, azimuth, polar, rotation]
# Input widgets: the image to analyse and the background-removal toggle.
demo_inputs = [
    gr.Image(height=512, width=512, label="upload your image"),
    gr.Checkbox(label="Remove Background", value=True),
]

# Output widgets: the annotated image followed by the three angle values.
demo_outputs = [
    gr.Image(height=512, width=512, label="result image"),
    gr.Textbox(lines=1, label='Azimuth(0~360°)'),
    gr.Textbox(lines=1, label='Polar(-90~90°)'),
    gr.Textbox(lines=1, label='Rotation(-90~90°)'),
]

# Wire infer_func to the widgets and start the app.
server = gr.Interface(
    fn=infer_func,
    inputs=demo_inputs,
    outputs=demo_outputs,
    flagging_mode='never',
)
server.launch()