import os
import re
from typing import Dict, Mapping, Tuple

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download
from onnxruntime import InferenceSession

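# This script is a Gradio demo that tags images with SmilingWolf's WD tagger
# ONNX models downloaded from the Hugging Face Hub. It returns rating scores,
# the per-tag confidences above a threshold, and a prompt-style text string.
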
def make_square(img, target_size):
    # Pad the image with white borders so it becomes a square of at least target_size.
    old_size = img.shape[:2]
    desired_size = max(old_size)
    desired_size = max(desired_size, target_size)

    delta_w = desired_size - old_size[1]
    delta_h = desired_size - old_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [255, 255, 255]
    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)


def smart_resize(img, size):
    # Downscale with INTER_AREA and upscale with INTER_CUBIC; leave the image
    # untouched if it is already the right size.
    if img.shape[0] > size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    elif img.shape[0] < size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    return img


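# Together, make_square and smart_resize turn an arbitrary image into the square
# input the tagger expects: the image is first padded to a white square and then
# resized to the edge length read from the ONNX model's input shape at runtime.
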
class WaifuDiffusionInterrogator:
    def __init__(
        self,
        repo='SmilingWolf/wd-v1-4-vit-tagger',
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        mode: str = "auto"
    ) -> None:
        self.__repo = repo
        self.__model_path = model_path
        self.__tags_path = tags_path
        self._provider_mode = mode

        self.__initialized = False
        self._model, self._tags = None, None

    def _init(self) -> None:
        # Lazily download the ONNX model and tag list on first use.
        if self.__initialized:
            return

        model_path = hf_hub_download(self.__repo, filename=self.__model_path)
        tags_path = hf_hub_download(self.__repo, filename=self.__tags_path)

        self._model = InferenceSession(str(model_path))
        self._tags = pd.read_csv(tags_path)

        self.__initialized = True

    def _calculation(self, image: Image.Image) -> pd.DataFrame:
        self._init()

        # The model expects a square input; read its edge length from the input shape.
        _, height, _, _ = self._model.get_inputs()[0].shape

        # Flatten any transparency onto a white background and convert to RGB.
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # Reverse the channel order from RGB to BGR for the model.
        image = image[:, :, ::-1]

        image = make_square(image, height)
        image = smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # Run inference and attach the confidences to the tag table.
        input_name = self._model.get_inputs()[0].name
        label_name = self._model.get_outputs()[0].name
        confidence = self._model.run([label_name], {input_name: image})[0]

        full_tags = self._tags[['name', 'category']].copy()
        full_tags['confidence'] = confidence[0]

        return full_tags

    def interrogate(self, image: Image.Image) -> Tuple[Dict[str, float], Dict[str, float]]:
        full_tags = self._calculation(image)

        # Category 9 holds the rating tags; everything else is a regular tag.
        ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)
        tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)

        return ratings, tags


WAIFU_MODELS: Mapping[str, WaifuDiffusionInterrogator] = {
    'chen-vit': WaifuDiffusionInterrogator(),
    'chen-convnext': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-convnext-tagger'
    ),
    'chen-convnext2': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-convnextv2-tagger-v2'
    ),
    'chen-swinv2': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
    ),
    'chen-moat2': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-v1-4-moat-tagger-v2'
    ),
    'chen-convnext3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-convnext-tagger-v3'
    ),
    'chen-vit3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-vit-tagger-v3'
    ),
    'chen-swinv3': WaifuDiffusionInterrogator(
        repo='SmilingWolf/wd-swinv2-tagger-v3'
    ),
}

# Matches backslashes and parentheses so they can be escaped in prompt text.
RE_SPECIAL = re.compile(r'([\\()])')


def image_to_wd14_tags(image: Image.Image, model_name: str, threshold: float,
                       use_spaces: bool, use_escape: bool, include_ranks=False, score_descend=True) \
        -> Tuple[Mapping[str, float], str, Mapping[str, float]]:
    model = WAIFU_MODELS[model_name]
    ratings, tags = model.interrogate(image)

    # Keep only the tags whose confidence reaches the threshold.
    filtered_tags = {
        tag: score for tag, score in tags.items()
        if score >= threshold
    }

    text_items = []
    tags_pairs = filtered_tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
    for tag, score in tags_pairs:
        tag_outformat = tag
        if use_spaces:
            # Dash-space format: underscores become dashes and tags are joined by spaces.
            tag_outformat = tag_outformat.replace('_', '-')
        else:
            tag_outformat = tag_outformat.replace(' ', ', ')
            tag_outformat = tag_outformat.replace('_', ' ')
        if use_escape:
            tag_outformat = re.sub(RE_SPECIAL, r'\\\1', tag_outformat)
        if include_ranks:
            tag_outformat = f"({tag_outformat}:{score:.3f})"
        text_items.append(tag_outformat)
    if use_spaces:
        output_text = ' '.join(text_items)
    else:
        output_text = ', '.join(text_items)

    return ratings, output_text, filtered_tags


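# A direct, non-UI call of the tagging pipeline looks like this (a minimal
# sketch; the file name "example.png" is only an illustration):
#
#   image = Image.open("example.png")
#   ratings, prompt, tags = image_to_wd14_tags(image, 'chen-moat2', 0.5,
#                                              use_spaces=False, use_escape=True)
#   print(ratings)
#   print(prompt)
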
if __name__ == '__main__':
    with gr.Blocks(analytics_enabled=False, theme="NoCrypt/miku") as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Chen Chen')
                with gr.Row():
                    gr_model = gr.Radio(list(WAIFU_MODELS.keys()), value='chen-moat2', label='Chen')
                    gr_threshold = gr.Slider(0.0, 1.0, 0.5, label='Chen Chen Chen Chen Chen')
                with gr.Row():
                    gr_space = gr.Checkbox(value=True, label='Use DashSpace')
                    gr_escape = gr.Checkbox(value=True, label='Chen Text Escape')

                gr_btn_submit = gr.Button(value='橙', variant='primary')

            with gr.Column():
                gr_ratings = gr.Label(label='橙 橙')
                with gr.Tabs():
                    with gr.Tab("Chens"):
                        gr_tags = gr.Label(label='Chens')
                    with gr.Tab("Chen Text"):
                        gr_output_text = gr.TextArea(label='Chen Text')

        gr_btn_submit.click(
            image_to_wd14_tags,
            inputs=[gr_input_image, gr_model, gr_threshold, gr_space, gr_escape],
            outputs=[gr_ratings, gr_output_text, gr_tags],
            api_name="classify"
        )

    demo.queue(os.cpu_count()).launch()
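
# Because the click handler is exposed with api_name="classify", the endpoint can
# also be called programmatically. A minimal sketch, assuming the demo is running
# on Gradio's default local port and a recent gradio_client is installed
# ("example.png" is only an illustration):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   ratings, text, tags = client.predict(
#       handle_file("example.png"), "chen-moat2", 0.5, True, True,
#       api_name="/classify",
#   )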