File size: 2,949 Bytes
43a369c
 
e1a7ab3
43a369c
 
 
00fe360
738bdfa
43a369c
 
00fe360
43a369c
 
 
 
 
 
 
00fe360
43a369c
 
 
00fe360
43a369c
 
 
 
00fe360
43a369c
 
 
6965bae
0366edb
6965bae
0366edb
00fe360
6965bae
0366edb
00fe360
43a369c
 
 
738bdfa
00fe360
 
 
 
 
 
43a369c
 
74503df
43a369c
e1a7ab3
 
 
ed2605d
43a369c
 
e1a7ab3
 
43a369c
b03b419
6965bae
 
43a369c
 
 
 
00fe360
 
 
 
43a369c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from paths import *
import os
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
from inference import *
from utils import *

from huggingface_hub import hf_hub_download
# Fetch the Orient-Anything checkpoint from the Hugging Face Hub, caching it
# under the working directory. `resume_download` is intentionally not passed:
# it is deprecated in huggingface_hub (interrupted downloads are resumed
# automatically) and only triggers a warning on current releases.
ckpt_path = hf_hub_download(
    repo_id="Viglong/Orient-Anything",
    filename="croplargeEX2/dino_weight.pt",
    repo_type="model",
    cache_dir='./',
)
print(ckpt_path)

save_path = './'
device = 'cpu'

# Build the DINOv2 backbone + MLP head in evaluation configuration.
# NOTE(review): out_dim presumably concatenates 360 azimuth bins, 180 polar
# bins, 180 rotation bins and 2 confidence logits -- confirm in vision_tower.
dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    out_dim=360 + 180 + 180 + 2,
    evaluate=True,
    mask_dino=False,
    frozen_back=False,
)

dino.eval()
print('model create')

# NOTE(review): torch.load unpickles the downloaded file. The checkpoint comes
# from a fixed Hub repository, but consider `weights_only=True` once it is
# confirmed that the file contains nothing but a tensor state dict.
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
dino = dino.to(device)
print('weight loaded')

# Image processor paired with the backbone; DINO_LARGE comes from `paths`.
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')

def infer_func(img, do_rm_bkg, do_infer_aug):
    """Estimate the orientation of the object in ``img`` for the Gradio UI.

    Args:
        img: Input image as an array accepted by ``Image.fromarray``, or
            ``None`` when the form is submitted without an image.
        do_rm_bkg: Whether to remove the background before inference. Only
            consulted when ``do_infer_aug`` is falsy.
        do_infer_aug: Whether to use inference-time augmentation. This path
            always runs background removal, regardless of ``do_rm_bkg``.

    Returns:
        A five-element list: the result image followed by the four predicted
        values ``angles[0]`` .. ``angles[3]`` rounded to two decimals (shown
        in the UI as azimuth, polar, rotation and confidence). The result
        image is the preprocessed image with a rendered 3D axis overlaid when
        the confidence exceeds 0.5, and the untouched input otherwise.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard Image.fromarray(None) fails with an opaque
        # traceback; gr.Error surfaces a readable message in the UI instead.
        raise gr.Error("Please upload an image before running inference.")

    origin_img = Image.fromarray(img)
    if do_infer_aug:
        rm_bkg_img = background_preprocess(origin_img, True)
        angles = get_3angle_infer_aug(origin_img, rm_bkg_img, dino, val_preprocess, device)
    else:
        rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
        angles = get_3angle(rm_bkg_img, dino, val_preprocess, device)

    confidence = float(angles[3])
    if confidence > 0.5:
        # The renderer takes the first two angles in radians and the third
        # exactly as predicted.
        render_axis = render_3D_axis(np.radians(angles[0]), np.radians(angles[1]), angles[2])
        res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)
    else:
        # Low confidence: show the original input without an axis overlay.
        res_img = img

    rounded_angles = [round(float(angles[i]), 2) for i in range(3)]
    return [res_img, *rounded_angles, round(confidence, 2)]

# One example row per file in ./examples, in sorted filename order. The two
# trailing None entries are the example values for the two checkbox inputs.
example_files = [
    [os.path.join('examples', filename), None, None]
    for filename in sorted(os.listdir('examples'))
]
print(example_files)

# Input widgets, in the positional order expected by infer_func.
_input_widgets = [
    gr.Image(height=512, width=512, label="upload your image"),
    gr.Checkbox(label="Remove Background", value=True),
    gr.Checkbox(label="Inference time augmentation", value=False),
]

# Output widgets, in the order of the list returned by infer_func.
_output_widgets = [
    gr.Image(height=512, width=512, label="result image"),
    # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0],  label="3D Model"),
    gr.Textbox(lines=1, label='Azimuth(0~360°) represents the position of the viewer in the xy plane'),
    gr.Textbox(lines=1, label='Polar(-90~90°) indicating the height at which the viewer is located'),
    gr.Textbox(lines=1, label='Rotation(-90~90°) represents the angle of rotation of the viewer'),
    gr.Textbox(lines=1, label='Confidence(0~1) indicating whether the object has a meaningful orientation'),
]

# Assemble the demo with flagging disabled and start serving it.
server = gr.Interface(
    fn=infer_func,
    inputs=_input_widgets,
    outputs=_output_widgets,
    examples=example_files,
    flagging_mode='never',
)

server.launch()