File size: 4,373 Bytes
43a369c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
from paths import *
import numpy as np
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
import os
import matplotlib.pyplot as plt
import io
from PIL import Image

from huggingface_hub import hf_hub_download
# Fetch the OriNet checkpoint from the Hugging Face Hub (cached under ./);
# the returned local path is loaded into the model below.
ckpt_path = hf_hub_download(repo_id="Viglong/OriNet", filename="celarge/dino_weight.pt", repo_type="model", cache_dir='./')
print(ckpt_path)

# NOTE(review): save_path is not referenced anywhere else in this view of the file.
save_path = './'
# All inference runs on this device; get_3angle moves its input tensor here.
device = 'cpu'
# Orientation model in its 'large' DINOv2 configuration (1024-dim input).
# out_dim layout, as decoded by get_3angle: 360 azimuth bins, 180 polar bins,
# 60 rotation bins, then 2 trailing outputs that get_3angle does not read.
dino = DINOv2_MLP(
                    dino_mode   = 'large',
                    in_dim      = 1024,
                    out_dim     = 360+180+60+2,
                    evaluate    = True,
                    mask_dino   = False,
                    frozen_back = False
                ).to(device)

dino.eval()
print('model create')
# NOTE(review): torch.load unpickles the downloaded file; since it comes from
# the network, consider weights_only=True (available in torch >= 1.13).
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')
# Image preprocessor for the backbone; DINO_LARGE presumably comes from the
# `paths` star import — confirm.
val_preprocess   = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')


def get_3angle(image):
    """Run the orientation model on one image and decode its three angles.

    The model output is read as consecutive blocks of class scores:
    360 azimuth bins, then 180 polar bins, then 60 rotation bins. Each block
    is decoded with argmax and the bin index is shifted into degrees.

    Args:
        image: An image accepted by ``val_preprocess`` (the Gradio handler
            passes the uploaded image array).

    Returns:
        A float ``torch.Tensor`` of shape (3,) holding azimuth in [0, 359],
        polar in [-90, 89] and rotation in [-30, 29], all in degrees.
    """
    batch = val_preprocess(images=image)
    # No return_tensors was requested, so pixel_values is presumably a list of
    # NumPy arrays; stack it into a tensor on the inference device.
    batch['pixel_values'] = torch.from_numpy(np.array(batch['pixel_values'])).to(device)
    with torch.no_grad():
        logits = dino(batch)

    azimuth_end = 360
    polar_end = azimuth_end + 180
    rotation_end = polar_end + 60
    best_azimuth = logits[:, :azimuth_end].argmax(dim=-1)
    best_polar = logits[:, azimuth_end:polar_end].argmax(dim=-1)
    best_rotation = logits[:, polar_end:rotation_end].argmax(dim=-1)

    # Bin index -> degrees. Writing a 1-element tensor into a scalar slot
    # assumes exactly one image (batch size 1).
    degrees = torch.zeros(3)
    degrees[0] = best_azimuth
    degrees[1] = best_polar - 90
    degrees[2] = best_rotation - 30
    return degrees

def scale(x, factor=3):
    """Return ``x`` multiplied by ``factor``.

    get_proj2D_XYZ uses this to lengthen the projected axis vectors before
    they are drawn, so the arrows are clearly visible on the plot.

    Args:
        x: A number or NumPy array (anything supporting ``*``).
        factor: Multiplier applied to ``x``. Defaults to 3, the value that was
            previously hard-coded.

    Returns:
        ``x * factor``.
    """
    return x * factor
    
def get_proj2D_XYZ(phi, theta, gamma):
    """Return 2D drawing-plane components of the rotated x, y and z axes.

    Args:
        phi: Azimuth angle in radians.
        theta: Polar angle in radians.
        gamma: Rotation angle in radians.

    Returns:
        Tuple ``(x, y, z)`` of length-2 NumPy arrays, each passed through
        ``scale`` before being returned.
    """
    sin_phi, cos_phi = np.sin(phi), np.cos(phi)
    sin_theta, cos_theta = np.sin(theta), np.cos(theta)
    sin_gamma, cos_gamma = np.sin(gamma), np.cos(gamma)

    x_axis = np.array([
        -1 * sin_phi * cos_gamma - cos_phi * sin_theta * sin_gamma,
        sin_phi * sin_gamma - cos_phi * sin_theta * cos_gamma,
    ])
    y_axis = np.array([
        -1 * cos_phi * cos_gamma + sin_phi * sin_theta * sin_gamma,
        cos_phi * sin_gamma + sin_phi * sin_theta * cos_gamma,
    ])
    z_axis = np.array([cos_theta * sin_gamma, cos_theta * cos_gamma])

    return tuple(scale(axis) for axis in (x_axis, y_axis, z_axis))

# Draw one projected coordinate axis as a 2D arrow.
def draw_axis(ax, origin, vector, color, label=None):
    """Draw ``vector`` as an arrow starting at ``origin`` on matplotlib axes ``ax``.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        origin: (x, y) start point of the arrow, in data coordinates.
        vector: (dx, dy) arrow displacement, in data coordinates.
        color: Matplotlib color used for the arrow and the optional label.
        label: Optional text placed just past the arrow tip, at 1.1x the
            vector from ``origin``. Nothing is written when it is ``None``.
    """
    # angles='xy', scale_units='xy', scale=1 make the arrow span exactly
    # ``vector`` in data units instead of quiver's auto-scaled length.
    ax.quiver(origin[0], origin[1], vector[0], vector[1],
              angles='xy', scale_units='xy', scale=1, color=color)
    if label is not None:
        ax.text(origin[0] + vector[0] * 1.1, origin[1] + vector[1] * 1.1,
                label, color=color, fontsize=12)

def figure_to_img(fig):
    """Render a matplotlib figure to a PIL image via an in-memory JPEG.

    Args:
        fig: The matplotlib figure to render (tight bounding box).

    Returns:
        A fully loaded PIL ``Image`` that does not depend on the temporary
        buffer.
    """
    with io.BytesIO() as jpeg_bytes:
        fig.savefig(jpeg_bytes, format='JPG', bbox_inches='tight')
        jpeg_bytes.seek(0)
        # Image.open is lazy; .copy() forces the decode while the buffer is
        # still open (it is closed when the with-block exits).
        return Image.open(jpeg_bytes).copy()

# def generate_mutimodal(title, context, img):
#     return f"Title:{title}\nContext:{context}\n...{img}"

def generate_mutimodal(img):
    """Predict an image's orientation and overlay the predicted axes on it.

    Args:
        img: Image array from the ``gr.Image`` input, presumably shaped
            (H, W) or (H, W, C). ``None`` when nothing was uploaded.

    Returns:
        ``[overlay, azimuth, polar, rotation]``: a PIL image of ``img`` with
        the projected x (red), y (green) and z (blue) axes drawn from its
        centre, followed by the three predicted angles in degrees as floats.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of a preprocessing crash.
        raise gr.Error("Please upload an image first.")

    angles = get_3angle(img)
    # Plain Python floats for the trig below and for the returned values.
    azimuth, polar, rotation = (float(a) for a in angles)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Fit the image inside the [-5, 5] square, preserving aspect ratio:
        # the longer side spans the full 10 units.
        h, w = img.shape[:2]
        if h > w:
            extent = [-5 * w / h, 5 * w / h, -5, 5]
        else:
            extent = [-5, 5, -5 * h / w, 5 * h / w]
        # extent sets the data-coordinate region the image occupies.
        ax.imshow(img, extent=extent, zorder=0, aspect='auto')

        origin = np.array([0, 0])

        # Angles in radians; the rotation sign is negated before projection.
        phi = np.radians(azimuth)
        theta = np.radians(polar)
        gamma = np.radians(-1 * rotation)

        # Rotated axis vectors, projected to 2D.
        rot_x, rot_y, rot_z = get_proj2D_XYZ(phi, theta, gamma)

        draw_axis(ax, origin, rot_y, 'g')
        draw_axis(ax, origin, rot_z, 'b')
        draw_axis(ax, origin, rot_x, 'r')

        # Hide axes and grid so only the image and arrows are rendered.
        ax.set_axis_off()
        ax.grid(False)

        # Same [-5, 5] window used for the image extent above.
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)

        res_img = figure_to_img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(fig)

    return [res_img, azimuth, polar, rotation]

# Gradio UI: one uploaded image in; the annotated image plus the three
# predicted angles out.
server = gr.Interface(
    flagging_mode='never',
    fn=generate_mutimodal,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # get_3angle decodes 60 rotation bins as (index - 30), so the
        # displayed rotation always lies in [-30, 29] degrees.
        gr.Textbox(lines=1, label='Rotation(-30~30°)'),
    ],
)

server.launch()