File size: 2,819 Bytes
24cb51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
num_nn = 20
import gradio as gr
from PIL import Image

# Root directory that the relative image paths stored in the pickles are
# joined onto when an image is opened.
data_root = '/dccstor/elishc1/datasets/DomainNet'
# Directory holding the pre-computed per-domain feature pickles.
feat_dir = 'brad_feats'
domains = ['real', 'painting', 'clipart', 'sketch']
# Last component of the pickle file names: '<prefix>_<domain>_<shots>.pkl'.
shots = '-1'
# 'all' builds a search index for every entry of `domains`; any other value
# is taken as the name of the single domain to search in.
search_domain = 'all'
# Number of neighbours retrieved (and displayed) per searched domain.
num_results_per_domain = 5


def _load_pickle(prefix, domain):
    """Return the unpickled content of '<feat_dir>/<prefix>_<domain>_<shots>.pkl'."""
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. Only point `feat_dir` at trusted, locally generated feature files.
    path = os.path.join(feat_dir, f'{prefix}_{domain}_{shots}.pkl')
    with open(path, 'rb') as fp:
        return pickle.load(fp)


# Maps searched domain -> (loaded data, NearestNeighbors index fitted on data[1]).
# NOTE(review): the search gallery is read from the 'dst_*' files and the
# queries from the 'src_*' files; the variable names suggest the opposite.
# Confirm against the script that wrote the pickles.
src_data_dict = {}
search_domains = domains if search_domain == 'all' else [search_domain]
for d in search_domains:
    src_data = _load_pickle('dst', d)
    src_nn_fit = NearestNeighbors(
        n_neighbors=num_results_per_domain,
        algorithm='auto',
        n_jobs=-1,
    ).fit(src_data[1])
    src_data_dict[d] = (src_data, src_nn_fit)

# Maps query domain -> loaded query data; queries are available for every
# domain regardless of `search_domain`.
dst_data_dict = {}
for d in domains:
    dst_data_dict[d] = _load_pickle('src', d)
def query(query_index, query_domain):
    """Retrieve the nearest neighbours of one query image in every searched domain.

    Args:
        query_index: Position of the query inside ``dst_data_dict[query_domain]``.
            UI sliders may deliver it as a float, so it is truncated to ``int``.
        query_domain: Key of ``dst_data_dict`` selecting the query domain.

    Returns:
        A flat tuple: first the opened images (the query image followed by the
        matches of each searched domain, in ``src_data_dict`` order), then one
        caption per image in the same order (``'Query: <class>'`` followed by
        the class folder name of each match).
    """
    # A float cannot be used as a sequence index or slice bound.
    query_index = int(query_index)

    dst_data = dst_data_dict[query_domain]
    dst_img_path = os.path.join(data_root, dst_data[0][query_index])
    # The class name is the directory that directly contains the image file.
    query_class = dst_img_path.split('/')[-2]

    img_paths = [dst_img_path]
    captions = [f'Query: {query_class}']

    # Slice (rather than index) so kneighbors receives a single-row 2-D batch.
    query_feat = dst_data[1][query_index:query_index + 1]

    for src_data, src_nn_fit in src_data_dict.values():
        _, top_n_matches_ids = src_nn_fit.kneighbors(query_feat)
        # Row 0 holds the neighbour ids of the only query in the batch.
        src_img_paths = [
            os.path.join(data_root, src_data[0][ix])
            for ix in top_n_matches_ids[0]
        ]
        img_paths += src_img_paths
        captions += [p.split('/')[-2] for p in src_img_paths]

    return tuple(Image.open(p) for p in img_paths) + tuple(captions)
# When the script is re-run in the same interpreter session (e.g. a notebook
# cell), shut down the previously launched demo before building a new one.
try:
    demo.close()
except NameError:
    # First run: there is no previous demo to close.
    pass
except Exception as exc:  # best-effort cleanup must not prevent a relaunch
    print(f'Could not close the previous demo: {exc}')

demo = gr.Blocks()
with demo:
    gr.Markdown('## Select Query Domain: ')
    # Pre-select a domain so that pressing "Run" right away does not send an
    # empty selection to query().
    domain_drop = gr.Dropdown(domains, value=domains[0])
    # step=1 makes the slider produce whole-number query indices.
    slider = gr.Slider(0, 1000, step=1)
    image_button = gr.Button("Run")

    gr.Markdown('# Query Image')
    src_cap = gr.Label()
    src_img = gr.Image()

    # One caption/image pair per retrieved neighbour. Iterating over the
    # searched domains (the keys of src_data_dict) keeps the number of output
    # widgets equal to the number of values query() returns, also when only a
    # single domain is searched.
    out_images = []
    out_captions = []
    for d in src_data_dict:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())

    # The output order mirrors query()'s return value: all images first
    # (query image, then matches), followed by all captions in the same order.
    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )
# share=True additionally requests a public share link for the demo.
demo.launch(share=True)