# Spaces: Sleeping
import os
import pickle
import warnings

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors

# NOTE(review): num_nn is not referenced in the visible code; the neighbour
# count actually used is num_results_per_domain.
num_nn = 20
# --- Configuration ---------------------------------------------------------
# Directory that the image paths stored in the feature pickles are joined to.
# The default is a cluster-specific absolute path; set the DOMAINNET_ROOT
# environment variable to point at a local copy of the dataset instead.
data_root = os.environ.get('DOMAINNET_ROOT', '/dccstor/elishc1/datasets/DomainNet')
# Directory holding the pre-computed feature pickles.
feat_dir = 'brad_feats'
# Domains offered as query domains (and searched when search_domain == 'all').
domains = ['real', 'painting', 'clipart', 'sketch']
# Suffix embedded in the pickle file names (f'..._{shots}.pkl').
shots = '-1'
# Domain(s) to retrieve from: 'all' or a single entry of `domains`.
search_domain = 'all'
# Number of nearest neighbours retrieved and shown per searched domain.
num_results_per_domain = 5

# Fail fast on a typo instead of with a confusing FileNotFoundError later.
if search_domain != 'all' and search_domain not in domains:
    raise ValueError(
        f"search_domain must be 'all' or one of {domains}, got {search_domain!r}"
    )
def _load_search_index(domain):
    """Load one domain's searchable features and fit a k-NN index on them.

    Reads ``<feat_dir>/dst_<domain>_<shots>.pkl``. The unpickled object is
    indexed as: ``[0]`` image paths relative to ``data_root``, ``[1]`` the
    feature matrix the index is fitted on, ``[2]`` presumably per-image
    labels (TODO confirm).

    Returns:
        ``(src_data, nn_index)``, where ``nn_index`` returns
        ``num_results_per_domain`` neighbours per query.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted, locally
    # produced feature files.
    with open(os.path.join(feat_dir, f'dst_{domain}_{shots}.pkl'), 'rb') as fp:
        src_data = pickle.load(fp)
    nn_index = NearestNeighbors(n_neighbors=num_results_per_domain,
                                algorithm='auto', n_jobs=-1).fit(src_data[1])
    return src_data, nn_index


# Searchable galleries keyed by domain, in `domains` order when searching all.
# NOTE(review): gallery features come from dst_*.pkl while query features are
# loaded from src_*.pkl; the src/dst variable naming is inverted relative to
# the file names. Confirm this is intentional.
searched_domains = domains if search_domain == 'all' else [search_domain]
src_data_dict = {d: _load_search_index(d) for d in searched_domains}
def _read_query_features(domain):
    """Return the unpickled contents of ``<feat_dir>/src_<domain>_<shots>.pkl``.

    Only trusted, locally produced files should be loaded: pickle can execute
    arbitrary code on load.
    """
    pkl_path = os.path.join(feat_dir, f'src_{domain}_{shots}.pkl')
    with open(pkl_path, 'rb') as pkl_file:
        return pickle.load(pkl_file)


# Query-side data for every domain, keyed by domain name.
dst_data_dict = {domain: _read_query_features(domain) for domain in domains}
def _class_name(img_path):
    """Return the name of the directory that contains *img_path*.

    Used as the caption / class label. This presumably relies on a
    ``.../<class>/<file>`` dataset layout; TODO confirm.
    """
    return os.path.basename(os.path.dirname(img_path))


def _load_image(img_path):
    """Open *img_path* and return an in-memory copy so the file is closed."""
    with Image.open(img_path) as img:
        # Image.open is lazy and keeps the file handle open; copy() loads the
        # pixel data so the handle can be released when the `with` exits.
        return img.copy()


def query(query_index, query_domain):
    """Retrieve nearest-neighbour images from every searched domain.

    Args:
        query_index: Position of the query image in the query domain's data.
            UI sliders may deliver this as a float, so it is passed through
            ``int()``.
        query_domain: Key of ``dst_data_dict`` selecting the query domain.

    Returns:
        A flat tuple. First come the images: the query image, then the
        retrieved neighbours of each searched domain in ``src_data_dict``
        order. Then come the captions in the same order: ``'Query: <class>'``
        for the query and ``'<class>'`` for each neighbour.

    Raises:
        KeyError: if *query_domain* is not a loaded query domain.
        IndexError: if *query_index* is outside the query domain's data.
    """
    if query_domain not in dst_data_dict:
        raise KeyError(f'Unknown query domain {query_domain!r}; '
                       f'expected one of {list(dst_data_dict)}')
    dst_data = dst_data_dict[query_domain]

    query_index = int(query_index)
    num_queries = len(dst_data[0])
    # Negative indices are rejected too: the i:i+1 slice below would be empty.
    if not 0 <= query_index < num_queries:
        raise IndexError(f'query_index {query_index} out of range for domain '
                         f'{query_domain!r} ({num_queries} images)')

    dst_img_path = os.path.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    captions = [f'Query: {_class_name(dst_img_path)}']

    # One-row 2-D slice, because kneighbors expects a batch of samples.
    query_feat = dst_data[1][query_index:query_index + 1]
    for s_data, nn_index in src_data_dict.values():
        _, top_n_matches_ids = nn_index.kneighbors(query_feat)
        src_img_paths = [os.path.join(data_root, s_data[0][ix])
                         for ix in top_n_matches_ids[0]]
        img_paths += src_img_paths
        captions += [_class_name(p) for p in src_img_paths]

    return tuple(_load_image(p) for p in img_paths) + tuple(captions)
# If this script is re-executed in the same interpreter (presumably a notebook
# workflow), close the Gradio app left over from the previous run first.
try:
    demo.close()
except NameError:
    pass  # first run: no previous `demo` exists yet
except Exception as err:  # best-effort cleanup; don't block the relaunch
    warnings.warn(f'Could not close previous Gradio demo: {err!r}')
# --- Gradio UI -------------------------------------------------------------
demo = gr.Blocks()
with demo:
    gr.Markdown('## Select Query Domain: ')
    # Default to the first domain so "Run" never passes None as the domain.
    domain_drop = gr.Dropdown(domains, value=domains[0])
    # Explicit step=1 so every index is selectable; an unset step can be
    # coarser than 1 in some Gradio versions.
    slider = gr.Slider(0, 1000, step=1)
    image_button = gr.Button("Run")
    gr.Markdown('# Query Image')
    src_cap = gr.Label()
    src_img = gr.Image()
    out_images = []
    out_captions = []
    # One row of result slots per *searched* domain, in src_data_dict order.
    # query() produces num_results_per_domain results for each entry of
    # src_data_dict, so the output count also matches when search_domain is a
    # single domain rather than 'all'.
    for d in src_data_dict:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    # query() returns all images first, then all captions, in matching order.
    image_button.click(query, inputs=[slider, domain_drop],
                       outputs=[src_img] + out_images + [src_cap] + out_captions)

demo.launch(share=True)