# Page-header residue: SmilingWolf, "Update index to danbooru dataset v3" (03d2c4c)
import argparse
import json
import faiss
import gradio as gr
import numpy as np
import requests
from imgutils.tagging import wd14
# Markdown heading rendered at the top of the page (passed to gr.Markdown in main()).
TITLE = "## Danbooru Explorer"
# Markdown blurb rendered right below the title: lists the feature extractor,
# the extraction library and the indexing tools, and links to a related space.
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) as feature extractor
- [dghs-imgutils](https://github.com/deepghs/imgutils) for feature extraction
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing
Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""
def parse_args() -> argparse.Namespace:
    """Read the command-line options of the app.

    Returns:
        A namespace whose boolean ``share`` attribute is True only when the
        ``--share`` flag was given on the command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()
    return options
def danbooru_id_to_url(
    image_id,
    selected_ratings,
    api_username="",
    api_key="",
    timeout=10,
):
    """Resolve a Danbooru post ID to the URL of its large image.

    Args:
        image_id: ID of the Danbooru post to look up.
        selected_ratings: Rating names the caller accepts; each must be one
            of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Optional Danbooru login. Credentials are only sent
            when both ``api_username`` and ``api_key`` are non-empty.
        api_key: Optional Danbooru API key (see ``api_username``).
        timeout: Seconds to wait for the Danbooru API before giving up.

    Returns:
        The post's ``large_file_url``, or None when the request fails, the
        server does not answer with HTTP 200, the body is not valid JSON,
        the post carries no ``large_file_url``, or its rating is not among
        ``selected_ratings``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"

    # Let requests build the query string so the credentials get URL-encoded
    # instead of being spliced into the URL verbatim.
    params = {}
    if api_username and api_key:
        params = {"api_key": api_key, "login": api_username}

    try:
        # Without a timeout a stalled connection would block the caller forever.
        r = requests.get(
            post_url,
            headers=headers,
            params=params,
            timeout=timeout,
        )
    except requests.RequestException:
        # Network failures are reported like any other failed lookup.
        return None

    if r.status_code != 200:
        return None

    try:
        content = json.loads(r.text)
    except ValueError:
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
class SimilaritySearcher:
    """Image similarity search backed by a prebuilt Faiss index.

    The index, the array mapping index rows to image IDs and the index
    parameters are all read from the ``index/`` directory at construction.
    """

    def __init__(self):
        """Load the ID table and the Faiss index, then apply its parameters."""
        # Row i of the Faiss index corresponds to image ID images_ids[i].
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as infos_file:
            config = json.load(infos_file)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(
        self,
        img_input,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find images similar to ``img_input`` and resolve them to URLs.

        Args:
            img_input: Query image handed to the WD14 tagger; None when the
                user has not provided an image.
            selected_ratings: Rating names accepted in the results.
            n_neighbours: Number of nearest neighbours requested from the index.
            api_username: Optional Danbooru login forwarded to the URL lookup.
            api_key: Optional Danbooru API key forwarded to the URL lookup.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image_id>/<score>"`` with the score formatted to two
            decimals. Neighbours whose URL lookup returns None are omitted,
            so the list may be shorter than ``n_neighbours``.
        """
        # Nothing to search for: return an empty gallery instead of crashing.
        if img_input is None:
            return []

        embeddings = wd14.get_wd14_tags(
            img_input,
            model_name="SwinV2_v3",
            fmt="embedding",
        )
        # Faiss searches batches of vectors, so add a leading batch axis.
        embeddings = np.expand_dims(embeddings, 0)
        faiss.normalize_L2(embeddings)

        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            # Faiss pads missing results with label -1; indexing the ID table
            # with it would silently return the last image instead.
            if index < 0:
                continue

            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))

        return results
def main():
    """Build the Gradio interface and launch the app.

    Parses the command-line options, loads the similarity index once, wires
    the "Find similar images" button to ``SimilaritySearcher.predict`` and
    starts the server; ``--share`` is forwarded to ``demo.launch``.
    """
    args = parse_args()
    searcher = SimilaritySearcher()

    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a flattened source; confirm it matches the intended page layout.
    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    # Mask the key so the secret is not shown in clear text.
                    api_key = gr.Textbox(
                        label="Danbooru API Key",
                        type="password",
                    )
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # The input order must match the parameter order of predict().
        find_btn.click(
            fn=searcher.predict,
            inputs=[
                img_input,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()