Spaces:
Runtime error
Runtime error
File size: 5,385 Bytes
1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 a363bd1 ec82f37 a363bd1 ec82f37 1a03e08 d9c0937 bd2ecfe 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 f9767c2 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 1a03e08 d9c0937 bd2ecfe d9c0937 1a03e08 d9c0937 1a03e08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import glob
import gradio as gr
import pandas as pd
import faiss
import clip
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Markdown header rendered at the top of the Gradio page via gr.Markdown.
title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""
# Markdown blurb rendered under the title via gr.Markdown. It states that only
# text search is supported; keep it in sync with the UI, which exposes only a
# text query box.
description = r"""
Find text or images similar to your query text with this demo. Currently, it supports text search only.<br>
In this demo, we use a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/) because LAION is currently not available.
<br>
The content will be updated to include image search once LAION is available.
The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss)
"""
# From local file
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
# pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()
def download_all_index(dataset_dict, index_dir="data/faiss_index"):
    """Pre-fetch the FAISS index and caption metadata for every dataset.

    Only downloads files into ``index_dir``; nothing is loaded into memory.
    (Previously every index and all caption tables were fully read and then
    discarded, which wasted startup time and memory.)

    Args:
        dataset_dict: Iterable of dataset display names (e.g. the
            ``DATASET_NAME`` dict itself). Each name is resolved through the
            module-level ``DATASET_NAME`` mapping, as before.
        index_dir: Local directory the Hub files are mirrored into.

    Raises:
        KeyError: if a name is not a key of ``DATASET_NAME``.
    """
    repo_id = "Eun02/text_image_faiss_index"
    for name in dataset_dict:
        folder = DATASET_NAME[name]
        # One request per dataset: the index file plus its caption parquet files.
        snapshot_download(
            repo_id=repo_id,
            allow_patterns=[f"{folder}/text.index", f"{folder}/*.parquet"],
            repo_type="dataset",
            local_dir=index_dir,
        )
def load_faiss_index(dataset, index_dir="data/faiss_index"):
    """Download (if needed) and load the FAISS index and captions for a dataset.

    Args:
        dataset: Dataset display name; must be a key of ``DATASET_NAME``.
        index_dir: Local directory the Hub files are mirrored into.

    Returns:
        ``(index, text_list)`` where ``index`` is the FAISS index read from
        ``<index_dir>/<folder>/text.index`` and ``text_list`` is the
        concatenated ``caption`` column of the metadata parquet files.

    Raises:
        KeyError: if ``dataset`` is not a key of ``DATASET_NAME``.
        FileNotFoundError: if no metadata parquet files were downloaded.
    """
    repo_id = "Eun02/text_image_faiss_index"
    folder = DATASET_NAME[dataset]

    hf_hub_download(
        repo_id=repo_id,
        subfolder=folder,
        filename="text.index",
        repo_type="dataset",
        local_dir=index_dir,
    )
    # Download the caption metadata. NOTE(review): the pattern has no
    # "metadata/" component but the glob below reads "<folder>/metadata/";
    # this relies on the Hub's fnmatch-style "*" also matching "/".
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=f"{folder}/*.parquet",
        repo_type="dataset",
        local_dir=index_dir,
    )

    index = faiss.read_index(f"{index_dir}/{folder}/text.index")

    # Caption row i must correspond to index id i, so the files are
    # concatenated in a deterministic (lexicographic) order.
    parquet_files = sorted(glob.glob(f"{index_dir}/{folder}/metadata/*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(
            f"No caption metadata found under {index_dir}/{folder}/metadata"
        )
    # Only the caption column is used, so avoid loading the rest of each table.
    text_list = pd.concat(
        pd.read_parquet(path, columns=["caption"]) for path in parquet_files
    )["caption"].tolist()
    return index, text_list
def change_index(dataset):
    """Make ``dataset`` the active search index, loading it only on change.

    Updates the module-level ``INDEX``, ``TEXT_LIST`` and ``PREV_DATASET``
    when ``dataset`` differs from the currently loaded one, showing Gradio
    toast notifications around the load. Always returns ``None``.
    """
    global INDEX, TEXT_LIST, PREV_DATASET
    # Nothing to do when the requested dataset is already loaded.
    if dataset == PREV_DATASET:
        return None

    gr.Info("Load index...")
    loaded_index, loaded_texts = load_faiss_index(dataset)
    # Globals are only touched after a successful load, so a failure leaves
    # the previous index in place.
    INDEX, TEXT_LIST = loaded_index, loaded_texts
    PREV_DATASET = dataset
    gr.Info("Done!!")
    return None
@torch.inference_mode()  # explicit call form works on every PyTorch version
def get_emb(text, device="cpu"):
    """Encode a text query with the module-level CLIP model.

    Args:
        text: Query string. Tokenised with ``truncate=True`` so over-long
            queries are cut to CLIP's context length instead of raising.
        device: Device the token tensor is moved to; should match the device
            ``CLIP_MODEL`` was loaded on.

    Returns:
        A float32 NumPy array with one row (the query), L2-normalised along
        the last axis. NOTE(review): normalisation presumably makes FAISS
        inner-product scores equal cosine similarity — confirm index metric.
    """
    text_tokens = clip.tokenize([text], truncate=True)
    text_features = CLIP_MODEL.encode_text(text_tokens.to(device))
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy().astype("float32")
def _collect_hits(scores, ids, texts):
    """Pair FAISS hits with captions, dropping padding, blanks and duplicates."""
    hits = []
    seen = set()
    for score, ind in zip(scores, ids):
        # FAISS pads results with id -1 when fewer than k neighbours exist;
        # indexing a Python list with -1 would silently return the last caption.
        if ind < 0:
            continue
        raw = texts[ind]
        # Captions read from parquet may be missing (None/NaN); skip them.
        if not isinstance(raw, str):
            continue
        item_str = raw.strip()
        if not item_str:
            continue
        key = (item_str, score)
        if key in seen:
            continue
        seen.add(key)
        hits.append(key)
    return hits


def _format_hits(hits, show_score, numbering_prefix):
    """Render hits as newline-terminated entries, optionally numbered/scored."""
    parts = []
    for count, (item_str, score) in enumerate(hits, start=1):
        if numbering_prefix:
            item_str = f"###################### {count} ######################\n {item_str}"
        if show_score:
            item_str += f", {score:0.2f}"
        parts.append(f"{item_str}\n")
    return "".join(parts)


@torch.inference_mode()  # explicit call form works on every PyTorch version
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
    """Return the captions most similar to ``query_text`` in the active index.

    Args:
        top_k: Number of neighbours to request from FAISS (coerced to int,
            since UI sliders may deliver a float).
        show_score: Append each hit's similarity score to its line.
        numbering_prefix: Prefix each hit with a numbered separator line.
        output_file: Also write the result text to ``./output.txt``.
        query_text: The text query.

    Returns:
        ``(result_str, output_path)``; ``output_path`` is ``None`` unless
        ``output_file`` is truthy.

    Raises:
        gr.Error: if the query is missing or only whitespace.
    """
    if query_text is None or not query_text.strip():
        raise gr.Error("Query text is missing")

    text_embeddings = get_emb(query_text, device)
    scores, ids = INDEX.search(text_embeddings, int(top_k))
    hits = _collect_hits(scores[0], ids[0], TEXT_LIST)
    result_str = _format_hits(hits, show_score, numbering_prefix)

    output_path = None
    if output_file:
        output_path = "./output.txt"
        # Captions are arbitrary Unicode; don't depend on the platform encoding.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result_str)
    return result_str, output_path
# Load the CLIP model on CPU; get_emb/search_text read this module-level
# `device` when encoding queries.
device = "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)

# Dataset display name -> sub-folder of the FAISS index repository on the Hub.
DATASET_NAME = {
    "danbooru22": "booru22_000-300",
    "DiffusionDB": "diffusiondb",
}
DEFAULT_DATASET = "danbooru22"
# Name of the dataset currently loaded into INDEX / TEXT_LIST. Derived from
# DEFAULT_DATASET so the two cannot drift apart.
PREV_DATASET = DEFAULT_DATASET

# Fetch every dataset's files up front so they are already on disk when the
# user switches datasets.
download_all_index(DATASET_NAME)
# Load the default index into memory.
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Choices come from DATASET_NAME so the dropdown can't drift from the
        # name -> folder mapping.
        dataset = gr.Dropdown(label="dataset", choices=list(DATASET_NAME), value=DEFAULT_DATASET)
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8)
        with gr.Column():
            show_score = gr.Checkbox(label="Show score", value=False)
            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
            output_file = gr.Checkbox(label="Return text file", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button()
    result_text = gr.Textbox(label="retrieved text", interactive=False)
    result_file = gr.File(label="output file", visible=True)

    # The index is swapped lazily on click rather than on dropdown change;
    # the search only runs if that swap succeeded.
    btn.click(
        fn=change_index,
        inputs=[dataset],
        outputs=[result_text],
    ).success(
        fn=search_text,
        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
        outputs=[result_text, result_file],
    )

if __name__ == "__main__":
    demo.launch()
|