# image-label-2 / img_label.py
# pengdaqian
# init scan
# d39fc00
# raw
# history blame
# 7.46 kB
from __future__ import annotations

import functools
import io
import logging
import os
import urllib
from typing import Tuple, List, Any
from urllib.request import urlopen

import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import piexif
import piexif.helper
import PIL.Image
import requests

import dbimutils
import model
# Hugging Face access token passed to ``hf_hub_download``. It is read from the
# environment so a real token never has to be committed to source; the
# empty-string default matches the previously hard-coded value.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Hub repositories of the tagger variants selectable by ``change_model``.
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# File names fetched from each repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Cache of ONNX sessions keyed by model name; ``None`` means "not loaded yet"
# (``predict`` loads such entries on demand via ``change_model``). Defined at
# import time so ``predict``/``change_model`` do not hit a NameError when used
# before ``label_img`` has run.
loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None}
def change_model(model_name):
    """Load the ONNX session for *model_name* and store it in ``loaded_models``.

    Args:
        model_name: One of "SwinV2", "ConvNext", "ConvNextV2" or "ViT".

    Returns:
        The newly created ``onnxruntime.InferenceSession`` (also cached in
        the module-level ``loaded_models`` dict, which must already exist).

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    if model_name == "SwinV2":
        repo = SWIN_MODEL_REPO
    elif model_name == "ConvNext":
        repo = CONV_MODEL_REPO
    elif model_name == "ConvNextV2":
        repo = CONV2_MODEL_REPO
    elif model_name == "ViT":
        repo = VIT_MODEL_REPO
    else:
        # Without this branch an unknown name fell through and crashed with an
        # opaque UnboundLocalError on the cache assignment below.
        raise ValueError(f"Unknown model name: {model_name!r}")
    # Named ``session`` (not ``model``) so the imported ``model`` module is not shadowed.
    session = load_model(repo, MODEL_FILENAME)
    loaded_models[model_name] = session
    return session
def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
    """Fetch *model_filename* from the Hub repo *model_repo* and open it with onnxruntime.

    The download goes through ``huggingface_hub.hf_hub_download`` using the
    module's ``HF_TOKEN``.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the downloaded file path.
    """
    local_path = huggingface_hub.hf_hub_download(
        model_repo,
        model_filename,
        use_auth_token=HF_TOKEN,
    )
    return rt.InferenceSession(local_path)
def load_labels() -> tuple[list[Any], list[Any], list[Any], list[Any]]:
    """Download the tag table and split its rows by tag category.

    The CSV is always taken from the ConvNextV2 repository.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``:
        every tag name in table order, followed by the row positions whose
        ``category`` column is 9, 0 and 4 respectively.
    """
    csv_path = huggingface_hub.hf_hub_download(
        CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
    )
    table = pd.read_csv(csv_path)
    category = table["category"]

    def _rows_with(code):
        # Positional indexes of the rows carrying this category code.
        return list(np.where(category == code)[0])

    # Category codes as used here: 9 = rating, 0 = general, 4 = character.
    return (
        table["name"].tolist(),
        _rows_with(9),
        _rows_with(0),
        _rows_with(4),
    )
def _prepare_input(rawimage: "PIL.Image.Image", size: int) -> np.ndarray:
    """Turn a PIL image into a float32 BGR batch of one image sized by *size*.

    Squaring and resizing are delegated to ``dbimutils.make_square`` and
    ``dbimutils.smart_resize``.
    """
    # Flatten any transparency onto a white background before dropping alpha.
    rgba = rawimage.convert("RGBA")
    canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, mask=rgba)
    pixels = np.asarray(canvas.convert("RGB"))
    # PIL yields RGB; reverse the channel axis to get OpenCV-style BGR.
    pixels = pixels[:, :, ::-1]
    pixels = dbimutils.make_square(pixels, size)
    pixels = dbimutils.smart_resize(pixels, size)
    return np.expand_dims(pixels.astype(np.float32), 0)


def _read_metadata(rawimage: "PIL.Image.Image") -> tuple[dict, str]:
    """Return a filtered copy of ``rawimage.info`` and the generation-info text.

    The image's own ``info`` dict is copied, never modified.
    """
    items = dict(rawimage.info)
    geninfo = ""
    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode("utf8", errors="ignore")
        items["exif comment"] = exif_comment
        geninfo = exif_comment
    for field in (
        "jfif",
        "jfif_version",
        "jfif_unit",
        "jfif_density",
        "dpi",
        "exif",
        "loop",
        "background",
        "timestamp",
        "duration",
    ):
        items.pop(field, None)
    # A "parameters" entry, when present, takes precedence over the EXIF comment.
    geninfo = items.get("parameters", geninfo)
    return items, geninfo


def predict(
    image: PIL.Image.Image,
    model_name: str,
    general_threshold: float,
    character_threshold: float,
    tag_names: list[str],
    rating_indexes: list[np.int64],
    general_indexes: list[np.int64],
    character_indexes: list[np.int64],
    result_min_confidence: float = 0.4,
):
    """Tag one image with the named tagger model.

    Args:
        image: A PIL image, or a string fetched with ``dbimutils.read_img_from_url``.
        model_name: Key into ``loaded_models``; an unloaded (``None``) entry is
            loaded via ``change_model`` first.
        general_threshold: A general tag is kept only if its score exceeds this.
        character_threshold: A character tag is kept only if its score exceeds this.
        tag_names: Tag name for each model output column.
        rating_indexes, general_indexes, character_indexes: Output columns
            belonging to each tag category.
        result_min_confidence: Additional cutoff applied only to the returned
            ``character_res`` / ``general_res`` lists (``a`` and ``c`` use
            ``general_threshold`` alone). Defaults to the former hard-coded 0.4.

    Returns:
        dict with keys ``a`` (general tags, highest score first, underscores
        replaced by spaces and parentheses backslash-escaped), ``c`` (the same
        tags unmodified), ``rating`` (score of every rating tag),
        ``character_res`` and ``general_res`` (lists of
        ``{"tag": ..., "confidence": ...}`` dicts).

    Raises:
        TypeError: if ``image`` is neither a string nor a PIL image
            (a subclass of the generic ``Exception`` raised previously).
        KeyError: if ``model_name`` is not a key of ``loaded_models``.
    """
    if isinstance(image, str):
        rawimage = dbimutils.read_img_from_url(image)
    elif isinstance(image, PIL.Image.Image):
        rawimage = image
    else:
        raise TypeError("Invalid image type")

    # ``session`` rather than ``model`` so the imported ``model`` module is not shadowed.
    session = loaded_models[model_name]
    if session is None:
        session = change_model(model_name)

    # Input shape is unpacked as (batch, height, width, channels); only the
    # height is needed because the image is squared to that size.
    _, height, _, _ = session.get_inputs()[0].shape
    batch = _prepare_input(rawimage, height)

    input_name = session.get_inputs()[0].name
    label_name = session.get_outputs()[0].name
    probs = session.run([label_name], {input_name: batch})[0]
    labels = list(zip(tag_names, probs[0].astype(float)))

    # Ratings: the score of every rating tag is returned (no argmax is taken).
    rating = dict(labels[i] for i in rating_indexes)
    # General and character tags: keep those scoring above their threshold.
    general_res = {
        name: score
        for name, score in (labels[i] for i in general_indexes)
        if score > general_threshold
    }
    character_res = {
        name: score
        for name, score in (labels[i] for i in character_indexes)
        if score > character_threshold
    }

    ordered_tags = sorted(general_res, key=general_res.get, reverse=True)
    c = ", ".join(ordered_tags)
    # Prompt-style form; raw strings avoid the invalid "\(" escape sequence.
    a = c.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    # Image metadata is only used for diagnostics, so it is parsed only when
    # debug logging is enabled (instead of printing it on every call).
    log = logging.getLogger(__name__)
    if log.isEnabledFor(logging.DEBUG):
        items, geninfo = _read_metadata(rawimage)
        for key, text in items.items():
            log.debug("%s: %s", key, text)
        log.debug("geninfo %s", geninfo)
        log.debug("a %s", a)
        log.debug("c %s", c)
        log.debug("rating %s", rating)
        log.debug("character_res %s", character_res)
        log.debug("general_res %s", general_res)

    character_list = [
        {'tag': tag, 'confidence': score}
        for tag, score in character_res.items()
        if score > result_min_confidence
    ]
    general_list = [
        {'tag': tag, 'confidence': score}
        for tag, score in general_res.items()
        if score > result_min_confidence
    ]
    return {'a': a, 'c': c, 'rating': rating, 'character_res': character_list, 'general_res': general_list}
# Model names accepted by ``label_img`` / ``change_model``.
_MODEL_NAMES = ("SwinV2", "ConvNext", "ConvNextV2", "ViT")


@functools.lru_cache(maxsize=1)
def _cached_labels():
    """Return ``load_labels()``, downloading and parsing the tag table only once."""
    return load_labels()


def label_img(
    image: PIL.Image.Image | str,
    model: str,
    l_score_general_threshold: float,
    l_score_character_threshold: float,
):
    """Tag *image* with the selected model and return ``predict``'s result dict.

    Args:
        image: A PIL image or a URL string; strings starting with "http" are
            downloaded with ``dbimutils.read_img_from_url`` first.
        model: One of "SwinV2", "ConvNext", "ConvNextV2" or "ViT".
        l_score_general_threshold: Score cutoff for general tags.
        l_score_character_threshold: Score cutoff for character tags.
    """
    global loaded_models
    if isinstance(image, str) and image.startswith("http"):
        image = dbimutils.read_img_from_url(image)

    # Keep the session cache across calls instead of resetting it (and eagerly
    # rebuilding a ConvNextV2 session) on every request; ``predict`` loads the
    # requested model on demand when its entry is still ``None``.
    if not isinstance(globals().get("loaded_models"), dict):
        loaded_models = {}
    for name in _MODEL_NAMES:
        loaded_models.setdefault(name, None)

    tag_names, rating_indexes, general_indexes, character_indexes = _cached_labels()
    return predict(
        image=image,
        model_name=model,
        general_threshold=l_score_general_threshold,
        character_threshold=l_score_character_threshold,
        tag_names=tag_names,
        rating_indexes=rating_indexes,
        general_indexes=general_indexes,
        character_indexes=character_indexes,
    )
def write_image_tag(img_id: int, is_valid: bool, tags: List[model.ImageTag], callback_url: str):
    """Build the scan-callback request for one image and return it.

    Args:
        img_id: Identifier of the scanned image.
        is_valid: Validity flag forwarded into the request.
        tags: Tags to report for the image.
        callback_url: Currently unused (see TODO below).

    Returns:
        The constructed ``model.ImageScanCallbackRequest``. Previously it was
        built and discarded, so the call had no effect.
    """
    # TODO(review): ``callback_url`` is accepted but nothing is sent to it here;
    # confirm whether the caller delivers the returned request or whether this
    # function should post it itself.
    return model.ImageScanCallbackRequest(img_id=img_id, is_valid=is_valid, tags=tags)
if __name__ == "__main__":
    # Manual smoke run: tag one sample image with SwinV2 and print the result.
    score_general_threshold = 0.35
    score_character_threshold = 0.85
    ret = label_img(
        image='https://pub-9747017e9ec54620bfbe2385f14fe4d7.r2.dev/cnGirlYcy_v10_people_network_nannansleep/cnGirlYcy_v10_people_network_nannansleep_r_1679670778_0.png',
        model="SwinV2",
        l_score_general_threshold=score_general_threshold,
        l_score_character_threshold=score_character_threshold,
    )
    print(ret)