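# Gradio demo for Orient-Anything: given an input image, the app predicts the
# object's azimuth, polar, and rotation angles plus a confidence score, and
# overlays a rendered 3D axis on the image when the confidence is high enough.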
import gradio as gr
from paths import *
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
import numpy as np     # used below for np.radians (may also be re-exported by the wildcard imports)
from PIL import Image  # used below for Image.fromarray (may also be re-exported by the wildcard imports)
from inference import *
from utils import *
from huggingface_hub import hf_hub_download
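# Download the pretrained orientation head ("croplargeEX2" checkpoint) from the
# Viglong/Orient-Anything repo on the Hugging Face Hub and cache it locally.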
ckpt_path = hf_hub_download(
    repo_id="Viglong/Orient-Anything",
    filename="croplargeEX2/dino_weight.pt",
    repo_type="model",
    cache_dir='./',
    resume_download=True
)
print(ckpt_path)
save_path = './'
device = 'cpu'
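# Build the DINOv2(large)-backed MLP head and load the downloaded weights on CPU.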
dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    out_dim=360 + 180 + 180 + 2,  # azimuth (360) + polar (180) + rotation (180) + confidence (2) outputs
    evaluate=True,
    mask_dino=False,
    frozen_back=False
)
dino.eval()
print('model created')
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
dino = dino.to(device)
print('weights loaded')
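# Image preprocessor matching the DINOv2 backbone (DINO_LARGE is expected to
# come from the wildcard import of paths).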
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
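# Inference entry point: optionally strip the background, run the model (with
# or without test-time augmentation), and return the annotated image plus the
# four rounded predictions.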
def infer_func(img, do_rm_bkg, do_infer_aug):
    origin_img = Image.fromarray(img)
    if do_infer_aug:
        rm_bkg_img = background_preprocess(origin_img, True)
        angles = get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device)
    else:
        rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
        angles = get_3angle(rm_bkg_img, dino, val_preprocess, device)

    phi = np.radians(angles[0])    # azimuth, converted to radians for the renderer
    theta = np.radians(angles[1])  # polar, converted to radians for the renderer
    gamma = angles[2]              # rotation, passed to the renderer without radian conversion
    confidence = float(angles[3])

    # Overlay the rendered orientation axis only when the model is confident
    # that the object has a meaningful orientation.
    if confidence > 0.5:
        render_axis = render_3D_axis(phi, theta, gamma)
        res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
    else:
        res_img = img

    # axis_model = "axis.obj"
    return [res_img,
            round(float(angles[0]), 2),
            round(float(angles[1]), 2),
            round(float(angles[2]), 2),
            round(float(angles[3]), 2)]
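# Gradio UI: one image input and two checkboxes; outputs the result image and
# the four numeric predictions as text boxes.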
server = gr.Interface(
    fn=infer_func,
    flagging_mode='never',
    inputs=[
        gr.Image(height=512, width=512, label="Upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference-time augmentation", value=False)
    ],
    outputs=[
        gr.Image(height=512, width=512, label="Result image"),
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth (0~360°): position of the viewer in the xy plane'),
        gr.Textbox(lines=1, label='Polar (-90~90°): height (elevation) at which the viewer is located'),
        gr.Textbox(lines=1, label='Rotation (-90~90°): rotation angle of the viewer'),
        gr.Textbox(lines=1, label='Confidence (0~1): whether the object has a meaningful orientation')
    ]
)
server.launch()