"""Gradio demo: retrieve dataset captions similar to a query text.

The query is embedded with CLIP (ViT-B/32) and looked up in a prebuilt FAISS
index that is downloaded from the Hugging Face Hub at start-up.
"""

import glob
import gradio as gr
import pandas as pd
import faiss
import clip
import torch
from huggingface_hub import hf_hub_download, snapshot_download

title = r"""

🔍 Search Similar Text/Image in the Dataset

"""

description = r""" Find text or images similar to your query text with this demo. Currently, it supports text search only.
In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
The content will be updated to include image search once LAION is available. The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss) """

# From local file
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
#     pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()

# Hub dataset repository that hosts the prebuilt indices.
HF_REPO_ID = "Eun02/text_image_faiss_index"
# Local directory the Hub files are downloaded into.
INDEX_ROOT = "data/faiss_index"

# Display name shown in the UI -> sub-folder name inside the Hub repository.
DATASET_NAME = {
    "danbooru22": "booru22_000-300",
    "DiffusionDB": "diffusiondb",
}
DEFAULT_DATASET = "danbooru22"


def _download_index(dataset):
    """Download the index file and caption parquet files for *dataset*.

    Args:
        dataset: Display name; must be a key of ``DATASET_NAME``.

    Returns:
        Local directory path that now holds the dataset's files.

    Raises:
        KeyError: If *dataset* is not a key of ``DATASET_NAME``.
    """
    subfolder = DATASET_NAME[dataset]
    hf_hub_download(
        repo_id=HF_REPO_ID,
        subfolder=subfolder,
        filename="text.index",
        repo_type="dataset",
        local_dir=INDEX_ROOT,
    )
    # Download text file
    snapshot_download(
        repo_id=HF_REPO_ID,
        allow_patterns=f"{subfolder}/*.parquet",
        repo_type="dataset",
        local_dir=INDEX_ROOT,
    )
    return f"{INDEX_ROOT}/{subfolder}"


def download_all_index(dataset_dict):
    """Download the index files of every dataset named in *dataset_dict*.

    Only downloads; nothing is read into memory here, so start-up does not
    build and throw away an index per dataset.

    Args:
        dataset_dict: Mapping (or any iterable) of dataset display names.
    """
    for name in dataset_dict:
        _download_index(name)


def load_faiss_index(dataset):
    """Download (if needed) and load the FAISS index and captions of *dataset*.

    Args:
        dataset: Display name; must be a key of ``DATASET_NAME``.

    Returns:
        Tuple ``(index, text_list)``: the FAISS index read from ``text.index``
        and the list of values in the ``caption`` column of the metadata
        parquet files, concatenated in sorted file-name order.
    """
    dataset_dir = _download_index(dataset)
    index = faiss.read_index(f"{dataset_dir}/text.index")
    # Sorted so the caption order is deterministic across file systems.
    parquet_files = sorted(glob.glob(f"{dataset_dir}/metadata/*.parquet"))
    text_list = pd.concat(
        pd.read_parquet(file) for file in parquet_files
    )['caption'].tolist()
    return index, text_list


def change_index(dataset):
    """Switch the globally loaded index when the selected dataset changed.

    Updates the module globals ``INDEX``, ``TEXT_LIST`` and ``PREV_DATASET``.

    Args:
        dataset: Display name chosen in the dropdown.

    Returns:
        Always ``None``; the click handler routes it to the result textbox.
    """
    global INDEX, TEXT_LIST, PREV_DATASET
    if PREV_DATASET != dataset:
        gr.Info("Load index...")
        INDEX, TEXT_LIST = load_faiss_index(dataset)
        PREV_DATASET = dataset
        gr.Info("Done!!")
    return None


# Called form `inference_mode()` is the documented way to use it as a decorator.
@torch.inference_mode()
def get_emb(text, device="cpu"):
    """Encode *text* with the global CLIP model into a unit-norm embedding.

    Args:
        text: Query string; tokenized with ``truncate=True``.
        device: Torch device the tokens are moved to before encoding.

    Returns:
        float32 NumPy array holding the L2-normalized text features.
    """
    text_tokens = clip.tokenize([text], truncate=True)
    text_features = CLIP_MODEL.encode_text(text_tokens.to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_embeddings = text_features.cpu().numpy().astype('float32')
    return text_embeddings


@torch.inference_mode()
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
    """Search the loaded index for captions similar to *query_text*.

    Args:
        top_k: Number of neighbours to request; cast to ``int`` before use.
        show_score: Append the score (two decimals) to each result.
        numbering_prefix: Put a numbered banner line before each result.
        output_file: Also write the result text to ``./output.txt``.
        query_text: The text to search for.

    Returns:
        Tuple ``(result_str, output_path)``; ``output_path`` is ``None`` when
        *output_file* is falsy.

    Raises:
        gr.Error: If *query_text* is ``None`` or empty.
    """
    if not query_text:
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    # The slider value may arrive as a float; FAISS needs an integer k.
    scores, retrieved_texts = INDEX.search(text_embeddings, int(top_k))
    scores, retrieved_texts = scores[0], retrieved_texts[0]

    result_list = []
    seen = set()
    for score, ind in zip(scores, retrieved_texts):
        # FAISS pads with label -1 when it finds fewer than k neighbours;
        # indexing the list with -1 would silently return the last caption.
        if ind < 0:
            continue
        item_str = TEXT_LIST[ind].strip()
        if item_str == "":
            continue
        key = (item_str, float(score))
        if key in seen:
            continue
        seen.add(key)
        result_list.append((item_str, score))

    # Postprocessing text
    lines = []
    for count, (item_str, score) in enumerate(result_list, start=1):
        if numbering_prefix:
            item_str = f"###################### {count} ######################\n {item_str}"
        if show_score:
            item_str += f", {score:0.2f}"
        lines.append(f"{item_str}\n")
    result_str = "".join(lines)

    # file_name = query_text.replace(" ", "_")
    # if show_score:
    #     file_name += "_score"
    output_path = None
    if output_file:
        file_name = "output"
        output_path = f"./{file_name}.txt"
        # Explicit UTF-8 so non-ASCII captions do not depend on the platform
        # default encoding.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result_str)

    return result_str, output_path


# Load CLIP model
device = "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)

# Dataset currently held in INDEX / TEXT_LIST.
PREV_DATASET = DEFAULT_DATASET

# Download needed index
download_all_index(DATASET_NAME)

# Load default index
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        dataset = gr.Dropdown(label="dataset", choices=list(DATASET_NAME), value=DEFAULT_DATASET)
        # step=1 keeps the slider on whole numbers.
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8, step=1)
        with gr.Column():
            show_score = gr.Checkbox(label="Show score", value=False)
            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
            output_file = gr.Checkbox(label="Return text file", value=True)

    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    result_text = gr.Textbox(label="retrieved text", interactive=False)
    result_file = gr.File(label="output file", visible=True)

    #dataset.change(change_index, dataset, None)
    # Make sure the selected dataset is loaded first, then run the search.
    btn.click(
        fn=change_index,
        inputs=[dataset],
        outputs=[result_text],
    ).success(
        fn=search_text,
        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
        outputs=[result_text, result_file],
    )

demo.launch()