|
import argparse |
|
import json |
|
|
|
import faiss |
|
import gradio as gr |
|
import numpy as np |
|
import requests |
|
from imgutils.tagging import wd14 |
|
|
|
# Markdown heading shown at the top of the page (passed to gr.Markdown in main).
TITLE = "## Danbooru Explorer"
# Markdown blurb shown under the heading: credits the components the tool is
# built on and links to a related space.
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-swinv2-tagger-v3](https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) as feature extractor
- [dghs-imgutils](https://github.com/deepghs/imgutils) for feature extraction
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing

Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the demo.

    Returns:
        The parsed namespace. Its ``share`` attribute is ``True`` only when
        the ``--share`` flag was given on the command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()
    return options
|
|
|
|
|
def danbooru_id_to_url(
    image_id, selected_ratings, api_username="", api_key="", timeout=10.0
):
    """Resolve a Danbooru post id to the URL of its large image file.

    Args:
        image_id: Post id, interpolated into the ``/posts/{id}.json`` endpoint.
        selected_ratings: Iterable of accepted rating names; each must be one
            of "General", "Sensitive", "Questionable" or "Explicit".
        api_username: Danbooru login. Credentials are sent only when both
            this and ``api_key`` are non-empty.
        api_key: Danbooru API key paired with ``api_username``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` when no rating is selected,
        the request fails or does not answer with status 200, the response is
        not a JSON object, the post has no ``large_file_url``, or its rating
        is not among the accepted ones.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown rating name.
    """
    ratings_to_letters = {
        "General": "g",
        "Sensitive": "s",
        "Questionable": "q",
        "Explicit": "e",
    }
    acceptable_ratings = {ratings_to_letters[x] for x in selected_ratings}
    if not acceptable_ratings:
        # Nothing could pass the rating filter, so skip the network round-trip.
        return None

    post_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    headers = {"User-Agent": "image_similarity_tool"}

    # Let requests build the query string so the credentials get URL-encoded
    # instead of being spliced into the URL verbatim.
    params = {}
    if api_username and api_key:
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(
            post_url, headers=headers, params=params, timeout=timeout
        )
        if r.status_code != 200:
            return None
        content = r.json()
    except (requests.RequestException, ValueError):
        # A network failure or a malformed body is treated like any other
        # unavailable post, so one bad lookup does not abort the caller.
        return None

    if not isinstance(content, dict):
        return None
    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")
|
|
|
|
|
class SimilaritySearcher:
    """Nearest-neighbour image search over a prebuilt Faiss index.

    The index, the array mapping index rows to image ids, and the index
    parameters are loaded from the ``index/`` directory at construction time.
    """

    def __init__(self):
        """Load the id array and Faiss index and apply the stored parameters."""
        # Row i of this array is the image id of vector i in the index.
        self.images_ids = np.load("index/cosine_ids.npy")
        self.knn_index = faiss.read_index("index/cosine_knn.index")

        # Use a context manager so the file handle is closed deterministically.
        with open("index/cosine_infos.json", encoding="utf-8") as infos_file:
            config = json.load(infos_file)["index_param"]
        faiss.ParameterSpace().set_index_parameters(self.knn_index, config)

    def predict(
        self,
        img_input,
        selected_ratings,
        n_neighbours,
        api_username,
        api_key,
    ):
        """Find images similar to ``img_input`` and resolve them to URLs.

        Args:
            img_input: Query image handed to the WD14 tagger, or ``None``
                when no image was provided.
            selected_ratings: Rating names forwarded to ``danbooru_id_to_url``.
            n_neighbours: Number of nearest neighbours to request.
            api_username: Danbooru login forwarded to ``danbooru_id_to_url``.
            api_key: Danbooru API key forwarded to ``danbooru_id_to_url``.

        Returns:
            A list of ``(image_url, caption)`` tuples, where the caption is
            ``"<image id>/<distance with two decimals>"``. Neighbours whose
            URL could not be resolved are omitted, so the list may be shorter
            than ``n_neighbours``; it is empty when ``img_input`` is ``None``.
        """
        if img_input is None:
            # Nothing to embed; return an empty result instead of crashing.
            return []

        embeddings = wd14.get_wd14_tags(
            img_input,
            model_name="SwinV2_v3",
            fmt="embedding",
        )
        # Faiss works on C-contiguous float32 matrices of shape (n, dim).
        embeddings = np.ascontiguousarray(
            np.expand_dims(embeddings, 0), dtype=np.float32
        )
        faiss.normalize_L2(embeddings)

        # UI sliders may deliver the count as a float; Faiss needs an int.
        dists, indexes = self.knn_index.search(embeddings, k=int(n_neighbours))

        results = []
        for index, dist in zip(indexes[0], dists[0]):
            if index < 0:
                # Faiss pads with -1 when fewer than k neighbours exist;
                # using it as an array index would pick the wrong image.
                continue
            image_id = int(self.images_ids[index])
            current_url = danbooru_id_to_url(
                image_id,
                selected_ratings,
                api_username,
                api_key,
            )
            if current_url is not None:
                results.append((current_url, f"{image_id}/{dist:.2f}"))
        return results
|
|
|
|
|
def main():
    """Build the Gradio interface and serve it.

    Loads the similarity index, lays out the input widgets and the results
    gallery, wires the search button to ``SimilaritySearcher.predict`` and
    launches the app. The ``--share`` command-line flag is forwarded to
    ``launch``.
    """
    args = parse_args()
    searcher = SimilaritySearcher()

    # NOTE(review): the container nesting below was reconstructed from the
    # statement order -- confirm the button/gallery placement in the layout.
    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    # Mask the key so the secret is not shown in clear text.
                    api_key = gr.Textbox(
                        label="Danbooru API Key",
                        type="password",
                    )
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])

        # The input order must match the positional parameters of predict().
        find_btn.click(
            fn=searcher.predict,
            inputs=[
                img_input,
                selected_ratings,
                n_neighbours,
                api_username,
                api_key,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|