import argparse
import json

import faiss
import gradio as gr
import numpy as np
import requests
from imgutils.tagging import wd14

TITLE = "## Danbooru Explorer"
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) as the feature extractor
- [dghs-imgutils](https://github.com/deepghs/imgutils) for feature extraction
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing

Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
    """Resolve a Danbooru post ID to its image URL, or return None if the post
    cannot be fetched or its rating is not among the selected ones."""
    headers = {"User-Agent": "image_similarity_tool"}
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }

    acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]

    image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    if api_username != "" and api_key != "":
        image_url = f"{image_url}?api_key={api_key}&login={api_username}"

    r = requests.get(image_url, headers=headers)
    if r.status_code != 200:
        return None

    content = json.loads(r.text)
    image_url = content["large_file_url"] if "large_file_url" in content else None
    image_url = image_url if content["rating"] in acceptable_ratings else None
    return image_url


class SimilaritySearcher:
    def __init__(self):
        # Precomputed Danbooru post IDs and the FAISS index built over their
        # wd-swinv2-tagger-v3 embeddings (cosine similarity via normalized vectors).
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")
        config = json.loads(open("index/cosine_infos.json").read())["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(
        self,
        img_input,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        # Extract the tagger embedding for the query image.
        embeddings = wd14.get_wd14_tags(
            img_input,
            model_name="SwinV2_v3",
            fmt="embedding",
        )
        embeddings = np.expand_dims(embeddings, 0)
        # L2-normalize so the index search behaves as cosine similarity.
        faiss.normalize_L2(embeddings)

        dists, indexes = self.knn_index.search(embeddings, k=n_neighbours)
        neighbours_ids = self.images_ids[indexes][0]
        neighbours_ids = [int(x) for x in neighbours_ids]

        captions = []
        image_urls = []
        for image_id, dist in zip(neighbours_ids, dists[0]):
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                image_urls.append(current_url)
                captions.append(f"{image_id}/{dist:.2f}")
        return list(zip(image_urls, captions))


def main():
    args = parse_args()
    searcher = SimilaritySearcher()

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])
        find_btn.click(
            fn=searcher.predict,
            inputs=[
                img_input,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()
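
# Usage sketch (assumptions: the script is saved as app.py and the precomputed
# index files — cosine_ids.npy, cosine_knn.index, cosine_infos.json — exist
# under ./index/; neither name is stated in the source):
#   python app.py            # launch the Gradio demo locally
#   python app.py --share    # additionally create a public Gradio share link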