File size: 3,211 Bytes
3ec9877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f398fe3
 
 
3ec9877
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# DINOv2 ViT-B/14 backbone fetched through torch.hub, switched to eval mode and
# moved onto `device`; `eval()` and `to()` both return the module itself.
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval().to(device)

# Shared estimators that app_fn refits for each image: a 3-component PCA and a
# min-max scaler (clip=True clamps transformed values to the fitted range).
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)

def plot_img(img_array: np.ndarray) -> go.Figure:
    """Render an image array as a Plotly figure with axis tick labels hidden.

    Args:
        img_array: Image data accepted by ``plotly.express.imshow``. In this
            file it is the ``(H, W, 3)`` float array built by ``app_fn``, with
            values in [0, 1].

    Returns:
        The figure produced by ``px.imshow`` with x/y tick labels turned off.
    """
    fig = px.imshow(img_array)
    # Pixel coordinates carry no meaning for a feature map, so hide the ticks.
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False)
    )

    return fig


def app_fn(
        img: np.ndarray,
        threshold: float,
        object_larger_than_bg: bool
    ) -> go.Figure:
    """Visualize a PCA of DINOv2 patch features for ``img`` with the background masked.

    The image is resized to a 40x40 grid of 14-pixel patches and normalized with
    the ImageNet mean/std. It is then passed through the module-level ``dino``
    model. A 3-component PCA is fit on the patch features, and its first
    component, min-max scaled to [0, 1], is compared with ``threshold`` to split
    background from foreground. A second PCA, fit on the foreground patches
    only, provides the RGB colours. Background patches are drawn black.

    Args:
        img: Channels-last image with 3 channels and values in [0, 255], as
            delivered by ``gr.Image`` (assumed RGB; TODO confirm for
            non-default image modes). ``None`` if nothing was uploaded.
        threshold: Cut-off in [0, 1] applied to the scaled first component.
        object_larger_than_bg: If True, patches whose first component is above
            ``threshold`` are background; otherwise those below it are.

    Returns:
        A Plotly figure of the ``(40, 40, 3)`` foreground colour map.

    Raises:
        gr.Error: If no image was supplied, or if the threshold leaves fewer
            foreground patches than the PCA has components.
    """
    if img is None:
        # Gradio hands over None when the button is pressed before an upload.
        raise gr.Error("Please upload an image first.")

    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

    # The ViT-B/14 backbone works on 14x14-pixel patches, so a
    # (14 * patch_h, 14 * patch_w) input yields a patch_h x patch_w token grid.
    patch_h = 40
    patch_w = 40

    transform = T.Compose([
        T.Resize((14 * patch_h, 14 * patch_w)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # HWC array in [0, 255] -> CHW float tensor in [0, 1], as Normalize expects.
    img = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    # Drop the first token (presumably the CLS token) to keep one row per patch.
    features = out["x_prenorm"][:, 1:, :]
    features = features.squeeze(0)
    features = features.cpu().numpy()

    # First pass: the leading principal component, scaled to [0, 1], is used to
    # separate background from foreground.
    pca_features = pca.fit_transform(features)
    pca_features = scaler.fit_transform(pca_features)

    if object_larger_than_bg:
        pca_features_bg = pca_features[:, 0] > threshold
    else:
        pca_features_bg = pca_features[:, 0] < threshold
    pca_features_fg = ~pca_features_bg

    # PCA cannot fit more components than it has samples; extreme thresholds can
    # leave only one or two foreground patches, which would crash in sklearn.
    n_fg = int(pca_features_fg.sum())
    if n_fg < pca.n_components:
        raise gr.Error(
            f"Only {n_fg} foreground patch(es) remain at threshold {threshold}; "
            "adjust the threshold or toggle 'Object Larger than Background'."
        )

    # Second pass: PCA on the foreground patches alone, scaled to RGB in [0, 1].
    pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
    pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

    # Start from all zeros so background patches stay black.
    pca_features_rgb = np.zeros((patch_h * patch_w, 3))
    pca_features_rgb[pca_features_fg] = pca_features_fg_seg
    pca_features_rgb = pca_features_rgb.reshape(patch_h, patch_w, 3)

    return plot_img(pca_features_rgb)

if __name__=="__main__":
    title = "DINOv2"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        # Placeholder for a description; currently renders nothing.
        gr.Markdown(
            """
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
        # A Button's caption is its `value`; passing `label=` is ignored (the
        # button shows the default caption) or rejected, depending on the
        # Gradio version.
        btn = gr.Button("Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])
        # cache_examples=True makes Gradio run app_fn on these examples up front
        # so clicking one shows a stored result.
        examples = gr.Examples(
            examples=[
                ["assets/photo-1.jpg", 0.6, True],
                ["assets/photo-2.jpg", 0.7, True],
                ["assets/photo-3.jpg", 0.8, False]
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True
        )

    # Queue requests, holding at most 5 waiting events.
    demo.queue(max_size=5).launch()