aarbelle committed
Commit 24cb51e
1 Parent(s): fef403e

add requirements and app.py

Files changed (2)
  1. app.py +87 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import pickle
+ import os
+ from sklearn.neighbors import NearestNeighbors
+ import numpy as np
+ num_nn = 20
+ import gradio as gr
+ from PIL import Image
+
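+ # Paths and retrieval settings for the DomainNet demo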
+ data_root = '/dccstor/elishc1/datasets/DomainNet'
+ feat_dir = 'brad_feats'
+ domains = ['real', 'painting', 'clipart', 'sketch']
+ shots = '-1'
+ search_domain = 'all'
+ num_results_per_domain = 5
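+ # Build a nearest-neighbour index over the precomputed features of each searchable domain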
+ src_data_dict = {}
+ if search_domain == 'all':
+     for d in domains:
+         with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
+             src_data = pickle.load(fp)
+
+         src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain,
+                                       algorithm='auto', n_jobs=-1).fit(src_data[1])
+         src_data_dict[d] = (src_data, src_nn_fit)
+ else:
+     with open(os.path.join(feat_dir, f'dst_{search_domain}_{shots}.pkl'), 'rb') as fp:
+         src_data = pickle.load(fp)
+     src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain,
+                                   algorithm='auto', n_jobs=-1).fit(src_data[1])
+     src_data_dict[search_domain] = (src_data, src_nn_fit)
+
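+ # Load the query-side features and image paths for every domain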
+ dst_data_dict = {}
+ for d in domains:
+     with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
+         dst_data_dict[d] = pickle.load(fp)
+
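+ # Given a query index and domain, return the query image plus its nearest neighbours
+ # from each search domain, followed by class-name captions for all of them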
+ def query(query_index, query_domain):
+     dst_data = dst_data_dict[query_domain]
+     dst_img_path = os.path.join(data_root, dst_data[0][query_index])
+     img_paths = [dst_img_path]
+     q_cl = dst_img_path.split('/')[-2]
+     captions = [f'Query: {q_cl}']
+     for s_domain, s_data in src_data_dict.items():
+         _, top_n_matches_ids = s_data[1].kneighbors(dst_data[1][query_index:query_index+1])
+         top_n_labels = s_data[0][2][top_n_matches_ids][0]
+         src_img_pths = [os.path.join(data_root, s_data[0][0][ix]) for ix in top_n_matches_ids[0]]
+         img_paths += src_img_pths
+
+         for p in src_img_pths:
+             src_cl = p.split('/')[-2]
+             src_file = p.split('/')[-1]
+             captions.append(src_cl)
+     return tuple([Image.open(p) for p in img_paths]) + tuple(captions)
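+ # Close any previously launched demo before creating a new one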
+ try:
+     demo.close()
+ except:
+     pass
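+ # Build the Gradio UI: domain selector, query-index slider, the query image, and a grid of results per domain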
+ demo = gr.Blocks()
+ with demo:
+     gr.Markdown('## Select Query Domain: ')
+     domain_drop = gr.Dropdown(domains)
+     # domain_select_button = gr.Button("Select Domain")
+     slider = gr.Slider(0, 1000)
+     image_button = gr.Button("Run")
+
+     gr.Markdown('# Query Image')
+     src_cap = gr.Label()
+     src_img = gr.Image()
+
+     out_images = []
+     out_captions = []
+     for d in domains:
+         gr.Markdown(f'# {d.title()} Domain Images')
+         with gr.Row():
+             for _ in range(num_results_per_domain):
+                 with gr.Column():
+                     out_captions.append(gr.Label())
+                     out_images.append(gr.Image())
+
+     image_button.click(query, inputs=[slider, domain_drop],
+                        outputs=[src_img] + out_images + [src_cap] + out_captions)
+ demo.launch(share=True)
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy==1.20.1
+ Pillow==8.2.0
+ scikit-learn==0.24.1
+