Image2Body_gradio / scripts /generate_prompt.py
yeq6x's picture
@spaces.GPU
e7d7d62
raw
history blame
3.3 kB
import argparse
import csv
import os
import json
from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path
import spaces
# Side length, in pixels, of the square image fed to the tagger.
IMAGE_SIZE = 448

# Default tagger repository on the Hugging Face Hub and the files fetched from it.
DEFAULT_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
# Top-level model files; the last entry is the tag-list CSV.
MODEL_FILES = [
    "keras_metadata.pb",
    "saved_model.pb",
    "selected_tags.csv",
]
# Weight files, stored in a subdirectory of the model folder.
VAR_DIR = "variables"
VAR_FILES = [
    "variables.data-00000-of-00001",
    "variables.index",
]
# Name of the tag-list CSV (kept in sync with MODEL_FILES).
CSV_FILE = MODEL_FILES[-1]
def preprocess_image(image):
    """Convert an image into the square BGR float32 array the tagger consumes.

    The image is centred on a white (255) square canvas, then resized to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    Args:
        image: a PIL image, or an array-like that ``np.asarray`` turns into
            (H, W) grayscale, (H, W, 3) RGB or (H, W, 4) RGBA pixels.
            Non-PIL input is presumably uint8 in 0..255 — TODO confirm with
            callers.

    Returns:
        ``np.ndarray`` of shape (IMAGE_SIZE, IMAGE_SIZE, 3), dtype float32,
        channels in BGR order.
    """
    # Modes such as L / P / LA / CMYK would otherwise yield a 2-D array or a
    # channel count the model cannot take. Route anything with transparency
    # through RGBA so it can be composited below, and everything else to RGB.
    if isinstance(image, Image.Image) and image.mode not in ("RGB", "RGBA"):
        has_alpha = "A" in image.mode or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    img = np.asarray(image)
    if img.ndim == 2:
        # Grayscale array: replicate into three identical channels.
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[2] == 4:
        # Composite onto white so transparent areas match the padding colour.
        alpha = img[:, :, 3:4].astype(np.float32) / 255.0
        rgb = img[:, :, :3].astype(np.float32)
        blended = rgb * alpha + 255.0 * (1.0 - alpha)
        img = np.clip(blended, 0, 255).round().astype(np.uint8)
    img = img[:, :, ::-1]  # RGB -> BGR
    size = max(img.shape[:2])
    pad_x, pad_y = size - img.shape[1], size - img.shape[0]
    # Split the padding evenly on both sides so the content stays centred.
    img = np.pad(
        img,
        ((pad_y // 2, pad_y - pad_y // 2), (pad_x // 2, pad_x - pad_x // 2), (0, 0)),
        mode="constant",
        constant_values=255,
    )
    # INTER_AREA when shrinking, Lanczos when enlarging.
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return img.astype(np.float32)
def download_model_files(repo_id, model_dir, sub_dir, files, sub_files):
    """Download tagger model files from the Hugging Face Hub into ``model_dir``.

    Files in ``files`` end up at ``model_dir/<name>``. Files in ``sub_files``
    end up at ``model_dir/<sub_dir>/<name>``, mirroring the repository layout.

    Args:
        repo_id: Hub repository to download from.
        model_dir: local destination directory.
        sub_dir: repository subfolder holding ``sub_files``.
        files: file names at the repository root.
        sub_files: file names inside ``sub_dir``.
    """
    # ``local_dir`` replicates the repo structure (including ``subfolder``)
    # under model_dir. This replaces the deprecated ``cache_dir`` +
    # ``force_filename`` combination while producing the same paths.
    for name in files:
        hf_hub_download(repo_id, name, local_dir=model_dir, force_download=True)
    for name in sub_files:
        hf_hub_download(
            repo_id,
            name,
            subfolder=sub_dir,
            local_dir=model_dir,
            force_download=True,
        )
@spaces.GPU
def load_wd14_tagger_model(model_dir="wd14_tagger_model"):
    """Load the WD14 tagger as a Keras ``TFSMLayer``, downloading it if needed.

    Args:
        model_dir: local directory that holds, or will receive, the files in
            ``MODEL_FILES`` plus ``VAR_DIR``/``VAR_FILES``.

    Returns:
        ``TFSMLayer`` bound to the model's ``serving_default`` endpoint.
    """
    required_paths = [os.path.join(model_dir, name) for name in MODEL_FILES]
    required_paths += [os.path.join(model_dir, VAR_DIR, name) for name in VAR_FILES]
    # Check every file, not just the directory: an interrupted download would
    # otherwise leave a partial directory that is never repaired.
    if all(os.path.isfile(path) for path in required_paths):
        print("Using existing model")
    else:
        download_model_files(DEFAULT_REPO, model_dir, VAR_DIR, MODEL_FILES, VAR_FILES)
    return TFSMLayer(model_dir, call_endpoint='serving_default')
def read_tags_from_csv(csv_path):
    """Read the tagger's label CSV and return its data rows.

    Args:
        csv_path: path to a UTF-8 CSV whose header begins with
            ``tag_id,name,category``.

    Returns:
        List of data rows (header excluded); each row is a list of strings.

    Raises:
        ValueError: if the file is empty or the header is not as expected.
    """
    # newline="" is what the csv module requires for correct quoting handling.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"Empty CSV file: {csv_path}")
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if header[:3] != ["tag_id", "name", "category"]:
            raise ValueError(f"Unexpected CSV format: {header}")
        return list(reader)
# Tags that are never emitted in the generated prompt by default.
_DEFAULT_UNDESIRED_TAGS = frozenset({
    'one-piece_swimsuit',
    'swimsuit',
    'leotard',
    'saitama_(one-punch_man)',
    '1boy',
})
# Number of leading model outputs skipped before the general/character tags.
_NUM_SKIPPED_OUTPUTS = 4


def generate_tags(images, model_dir, model, threshold=0.35, undesired_tags=None):
    """Generate a comma-separated tag string for each image in a batch.

    Args:
        images: batch passed unchanged to ``model``; presumably stacked
            ``preprocess_image`` outputs — TODO confirm against the caller.
        model_dir: directory containing the tagger's tag-list CSV.
        model: callable invoked as ``model(images, training=False)`` that
            returns a mapping whose ``'predictions_sigmoid'`` value supports
            ``.numpy()``.
        threshold: minimum probability for a tag to be kept.
        undesired_tags: collection of tags to drop; ``None`` uses
            ``_DEFAULT_UNDESIRED_TAGS``.

    Returns:
        One string per image: tags whose probability is >= ``threshold``,
        joined by ", " in label order.
    """
    if undesired_tags is None:
        undesired_tags = _DEFAULT_UNDESIRED_TAGS
    rows = read_tags_from_csv(os.path.join(model_dir, CSV_FILE))
    # This assumes the model emits, after the skipped leading outputs, all
    # general tags (category "0") followed by all character tags ("4").
    label_names = [row[1] for row in rows if row[2] == "0"]
    label_names += [row[1] for row in rows if row[2] == "4"]
    probs = model(images, training=False)['predictions_sigmoid'].numpy()
    tag_text_list = []
    for prob in probs:
        kept = [
            tag
            for tag, p in zip(label_names, prob[_NUM_SKIPPED_OUTPUTS:])
            if p >= threshold and tag not in undesired_tags
        ]
        tag_text_list.append(", ".join(kept))
    return tag_text_list