deepdanbooru_onnx / README.md
skytnt's picture
Update README.md
817c2b7
|
raw
history blame
4.66 kB
metadata
license: mit

Model converted to ONNX from https://github.com/KichangKim/DeepDanbooru

Usage:

Basic usage

import cv2
import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download

import ast  # local import: only needed to parse the tag list below

tagger_model_path = hf_hub_download(repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")

# Providers are listed in order of preference: CUDA if available, otherwise CPU.
tagger_model = rt.InferenceSession(tagger_model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
tagger_model_meta = tagger_model.get_modelmeta().custom_metadata_map
# The tag vocabulary is stored as a string in the ONNX custom metadata.
# Parse it with ast.literal_eval instead of eval(): the model file comes from
# the network, and literal_eval only accepts Python literals (it cannot run code).
# NOTE(review): assumes the 'tags' entry is a list literal of strings — confirm.
tagger_tags = ast.literal_eval(tagger_model_meta['tags'])

def tagger_predict(image, score_threshold):
    """Return the tags whose predicted probability is at least ``score_threshold``.

    The image is letterboxed to 512x512: the longer side is scaled to 512, the
    shorter side keeps the aspect ratio, and the remaining area is filled by
    replicating the edge pixels. Pixel values are scaled to [0, 1].

    Args:
        image: a 3-D ``(height, width, channels)`` numpy array; the example
            usage passes an RGB image converted from ``cv2.imread``.
        score_threshold: minimum probability for a tag to be returned.

    Returns:
        List of tag strings, in the same order as ``tagger_tags``.
    """
    s = 512
    h, w = image.shape[:-1]
    # Scale the longer side to s. Clamp the shorter side to at least 1 pixel so
    # extremely thin images don't yield a zero-sized resize target.
    h, w = (s, max(1, int(s * w / h))) if h > w else (max(1, int(s * h / w)), s)
    ph, pw = s - h, s - w
    # cv2.resize takes the target size as (width, height).
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    # Split the padding as evenly as possible between the two sides.
    image = cv2.copyMakeBorder(image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
    image = image.astype(np.float32) / 255
    # Add the leading batch dimension expected by the session input.
    image = image[np.newaxis, :]
    probs = tagger_model.run(None, {"input_1": image})[0][0]
    probs = probs.astype(np.float32)
    return [label for prob, label in zip(probs.tolist(), tagger_tags) if prob >= score_threshold]

img = cv2.imread("test.jpg")
# cv2.imread returns None (rather than raising) when the file is missing or unreadable.
if img is None:
    raise FileNotFoundError("could not read image: test.jpg")
# OpenCV loads images as BGR; convert to RGB before tagging.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
tags = tagger_predict(img, 0.5)
print(tags)

Multi-GPU batch processing

import cv2
import torch
import os
import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
from threading import Thread


class MyDataset(Dataset):
    """Map-style dataset that loads image files and letterboxes them for the tagger.

    Each item is ``(image, idx)``:
      * ``image`` is a float32 tensor of shape ``(512, 512, 3)`` with values in [0, 1];
      * ``idx`` is a 1-element int32 tensor holding the item's position in
        ``image_list``, so batched predictions can be mapped back to file paths.
    """

    def __init__(self, image_list):
        # Paths of the image files; list order defines the dataset indices.
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        # The context manager closes the file handle even if decoding fails,
        # so long-lived DataLoader workers do not accumulate open files.
        with Image.open(self.image_list[index]) as im:
            image = np.asarray(im.convert("RGB"))
        s = 512
        h, w = image.shape[:-1]
        # Scale the longer side to s, keep the aspect ratio, and never let the
        # shorter side collapse to 0 pixels.
        h, w = (s, max(1, int(s * w / h))) if h > w else (max(1, int(s * h / w)), s)
        ph, pw = s - h, s - w
        # cv2.resize takes the target size as (width, height).
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        # Pad to s x s by replicating edge pixels, splitting the padding evenly.
        image = cv2.copyMakeBorder(image, ph // 2, ph - ph // 2, pw // 2, pw - pw // 2, cv2.BORDER_REPLICATE)
        image = image.astype(np.float32) / 255
        image = torch.from_numpy(image)
        idx = torch.tensor([index], dtype=torch.int32)
        return image, idx


def get_images(path):
    """Recursively collect .png/.jpg/.jpeg files under ``path``.

    The extension match is case-insensitive. The number of matches is printed,
    and the matching paths (each prefixed with ``path``) are returned sorted.
    """
    image_exts = (".png", ".jpg", ".jpeg")

    relative = set()
    for dirpath, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            relative.add(os.path.relpath(os.path.join(dirpath, filename), path))

    matches = [
        os.path.join(path, rel)
        for rel in relative
        if os.path.splitext(rel)[1].lower() in image_exts
    ]
    matches.sort()
    print(len(matches))
    return matches


def process(all_images, batch_size=8, score_threshold=0.35):
    """Tag ``all_images`` in parallel, one thread per session in ``tagger_model``.

    The image list is split into contiguous shards whose sizes differ by at most
    one. Each shard is loaded by its own DataLoader and run on the session with
    the same index.

    Args:
        all_images: sliceable sequence of image file paths.
        batch_size: number of images per inference call.
        score_threshold: minimum probability for a tag to be kept (inclusive,
            matching ``tagger_predict``).

    Returns:
        Dict mapping each image path to a list of ``(tag, probability)`` pairs.

    Raises:
        RuntimeError: if any worker thread fails; the original exception is
            chained as ``__cause__``.
    """
    predictions = {}
    errors = []

    def work_fn(images, device_id):
        try:
            dataset = MyDataset(images)
            dataloader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                persistent_workers=True,
                num_workers=4,
                pin_memory=True,
            )
            # One progress bar per device, stacked instead of overwriting each other.
            for image, idxs in tqdm(dataloader, desc=f"device {device_id}", position=device_id):
                probs = tagger_model[device_id].run(None, {"input_1": image.numpy()})[0]
                probs = probs.astype(np.float32)
                # idxs holds each item's position in `images` (see MyDataset).
                for row, idx in zip(probs, idxs):
                    predictions[images[idx.item()]] = [
                        (label, prob)
                        for prob, label in zip(row.tolist(), tagger_tags)
                        if prob >= score_threshold
                    ]
        except Exception as exc:
            # An uncaught exception in a thread is only printed, and the caller
            # would silently get incomplete results. Record it and re-raise below.
            errors.append((device_id, exc))

    gpu_num = len(tagger_model)
    # Balanced split: the first `extra` shards get one extra image. The former
    # `n // gpu_num + 1` shard size could leave the last device with few or no
    # images (e.g. 8 images on 4 GPUs -> 3/3/2/0).
    base, extra = divmod(len(all_images), gpu_num)
    shards = []
    start = 0
    for i in range(gpu_num):
        end = start + base + (1 if i < extra else 0)
        shards.append(all_images[start:end])
        start = end

    # Skip empty shards (fewer images than devices) rather than starting
    # DataLoader workers with nothing to do.
    ts = [Thread(target=work_fn, args=(shard, i)) for i, shard in enumerate(shards) if len(shard)]
    for t in ts:
        t.start()
    for t in ts:
        t.join()

    if errors:
        device_id, exc = errors[0]
        raise RuntimeError(f"tagging failed on device {device_id}") from exc
    return predictions


if __name__ == "__main__":
    # The guard is required: DataLoader(num_workers>0) may start worker
    # processes with the "spawn" method (the default on Windows and macOS),
    # which re-imports this module. Unguarded, each worker would re-run this
    # whole script (re-downloading the model and opening CUDA sessions), and
    # multiprocessing would abort with a bootstrapping RuntimeError.
    # Names assigned here remain module globals, so process() can still see
    # tagger_model and tagger_tags.
    import ast  # local import: only needed to parse the tag list below

    gpu_num = 4
    batch_size = 8
    tagger_model_path = hf_hub_download(repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")
    # One inference session pinned to each GPU (device_id 0 .. gpu_num - 1).
    tagger_model = [
        rt.InferenceSession(tagger_model_path, providers=['CUDAExecutionProvider'], provider_options=[{'device_id': i}])
        for i in range(gpu_num)
    ]
    tagger_model_meta = tagger_model[0].get_modelmeta().custom_metadata_map
    # Parse the downloaded tag list with ast.literal_eval instead of eval(),
    # which would execute arbitrary code embedded in the model metadata.
    # NOTE(review): assumes the 'tags' entry is a list literal of strings — confirm.
    tagger_tags = ast.literal_eval(tagger_model_meta['tags'])

    all_images = get_images("./data")
    predictions = process(all_images, batch_size)