Image2Body_gradio / scripts /process_utils.py
yeq6x's picture
set_lora_device
a742794
import io
import os
import base64
from PIL import Image
import cv2
import numpy as np
from scripts.generate_prompt import load_wd14_tagger_model, generate_tags, preprocess_image as wd14_preprocess_image
from scripts.lineart_util import scribble_xdog, get_sketch, canny
from scripts.anime import init_model
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, AutoencoderKL
import gc
from peft import PeftModel
from dotenv import load_dotenv
from scripts.hf_utils import download_file
import spaces
# Global variables: module-level state that initialize() fills in.
# When True, get_file_path() builds "<subfolder>/<filename>" instead of calling download_file().
use_local = False
# Tagger model returned by load_wd14_tagger_model(); None until initialize() runs.
model = None
# "cuda" or "cpu"; chosen in initialize().
device = None
torch_dtype = None # torch.float16 if device == "cuda" else torch.float32
# ControlNet pipeline returned by initialize_sotai_model().
sotai_gen_pipe = None
# ControlNet pipeline returned by initialize_refine_model().
refine_gen_pipe = None
def get_file_path(filename, subfolder):
    """Return a path for ``filename`` located under ``subfolder``.

    When the module-level ``use_local`` flag is set, the path is simply
    ``subfolder + "/" + filename``. Otherwise the lookup is delegated to
    ``download_file`` and whatever it returns is passed back unchanged.
    """
    if not use_local:
        return download_file(filename, subfolder)
    return subfolder + "/" + filename
def ensure_rgb(image):
    """Return ``image`` in RGB mode.

    An image whose ``mode`` is already ``'RGB'`` is returned as-is (same
    object); any other mode is converted via ``image.convert('RGB')``.
    """
    return image if image.mode == 'RGB' else image.convert('RGB')
def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
    """Populate the module-level state and load every model.

    Args:
        _use_local: Stored in the global ``use_local`` flag and forwarded to
            ``init_model``.
        use_gpu: Request CUDA; it is only used when ``torch.cuda.is_available()``.
        use_dotenv: When true, ``load_dotenv()`` is called first.

    Side effects:
        Sets the globals ``device``, ``torch_dtype``, ``use_local``, ``model``,
        ``sotai_gen_pipe`` and ``refine_gen_pipe``, and prints the chosen
        device.
    """
    global model, sotai_gen_pipe, refine_gen_pipe, use_local, device, torch_dtype

    if use_dotenv:
        load_dotenv()

    # Half precision is only selected together with CUDA.
    if use_gpu and torch.cuda.is_available():
        device, torch_dtype = "cuda", torch.float16
    else:
        device, torch_dtype = "cpu", torch.float32
    use_local = _use_local

    print(f"\nDevice: {device}, Local model: {_use_local}\n")

    init_model(use_local)
    model = load_wd14_tagger_model()
    sotai_gen_pipe = initialize_sotai_model()
    refine_gen_pipe = initialize_refine_model()
def load_lora(pipeline, lora_path, adapter_name, alpha=0.75):
    """Load a LoRA into ``pipeline``, fuse it, and place it on the module's device.

    Args:
        pipeline: Pipeline object exposing ``load_lora_weights``, ``fuse_lora``
            and ``set_lora_device``.
        lora_path: Location of the LoRA weights.
        adapter_name: Name under which the adapter is registered.
        alpha: Scale passed to ``fuse_lora`` as ``lora_scale``.
    """
    adapters = [adapter_name]
    pipeline.load_lora_weights(lora_path, adapter_name)
    pipeline.fuse_lora(lora_scale=alpha, adapter_names=adapters)
    # ``device`` is the module-level global set by initialize().
    pipeline.set_lora_device(adapter_names=adapters, device=device)
def initialize_sotai_model():
    """Build the two-ControlNet pipeline stored in ``sotai_gen_pipe``.

    File names and folders are read from environment variables
    (``sotai_sd_model_name``, ``sd_models_dir``, ``controlnet_name1``,
    ``controlnet_dir2``, ``controlnet_name2``, ``controlnet_dir1``,
    ``lora_name1``) and resolved through ``get_file_path``.

    Returns:
        A ``StableDiffusionControlNetPipeline`` on the module's ``device``
        that reuses the components of the base checkpoint, carries both
        ControlNets, and uses a ``UniPCMultistepScheduler``.

    Raises:
        KeyError: If any of the environment variables above is missing.
    """
    env = os.environ

    # Resolve every file first (this may download), then load in the same order.
    checkpoint_path = get_file_path(env["sotai_sd_model_name"], subfolder=env["sd_models_dir"])
    controlnet_paths = [
        get_file_path(env["controlnet_name1"], subfolder=env["controlnet_dir2"]),
        get_file_path(env["controlnet_name2"], subfolder=env["controlnet_dir1"]),
    ]

    # Base Stable Diffusion checkpoint; its components are reused below.
    base_pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    ).to(device)

    # One ControlNet per resolved path, loaded in order.
    controlnets = [
        ControlNetModel.from_single_file(path, torch_dtype=torch_dtype).to(device)
        for path in controlnet_paths
    ]

    # Assemble the ControlNet pipeline from the base pipeline's parts.
    pipe = StableDiffusionControlNetPipeline(
        vae=base_pipe.vae,
        text_encoder=base_pipe.text_encoder,
        tokenizer=base_pipe.tokenizer,
        unet=base_pipe.unet,
        scheduler=base_pipe.scheduler,
        safety_checker=base_pipe.safety_checker,
        feature_extractor=base_pipe.feature_extractor,
        controlnet=controlnets,
    ).to(device)

    # LoRA application is currently disabled; the list is kept (and still
    # requires the ``lora_name1`` environment variable) so the loop below can
    # be re-enabled without further changes.
    lora_names = [
        (env["lora_name1"], 1.0),
        # (env["lora_name2"], 0.3),
    ]
    # for lora_name, alpha in lora_names:
    #     lora_path = get_file_path(lora_name, subfolder=os.environ["lora_dir"])
    #     load_lora(pipe, lora_path, adapter_name=lora_name.split(".")[0], alpha=alpha)

    # Swap in the UniPC multistep scheduler, keeping the existing scheduler config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
def initialize_refine_model():
    """Build the two-ControlNet pipeline stored in ``refine_gen_pipe``.

    File names and folders are read from environment variables
    (``refine_sd_model_name``, ``sd_models_dir``, ``controlnet_name3``,
    ``controlnet_name4``, ``controlnet_dir1``, ``vae_name``, ``vae_dir``) and
    resolved through ``get_file_path``.

    Returns:
        A ``StableDiffusionControlNetPipeline`` on the module's ``device``
        that uses a separately loaded VAE, the remaining components of the
        base checkpoint, both ControlNets, and a ``UniPCMultistepScheduler``.

    Raises:
        KeyError: If any of the environment variables above is missing.
    """
    env = os.environ

    # Resolve every file first (this may download), then load in the same order.
    checkpoint_path = get_file_path(env["refine_sd_model_name"], subfolder=env["sd_models_dir"])
    controlnet_paths = [
        get_file_path(env["controlnet_name3"], subfolder=env["controlnet_dir1"]),
        get_file_path(env["controlnet_name4"], subfolder=env["controlnet_dir1"]),
    ]
    vae_path = get_file_path(env["vae_name"], subfolder=env["vae_dir"])

    # Base Stable Diffusion checkpoint; everything except the VAE is reused.
    base_pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=torch_dtype,
        variant="fp16",
        use_safetensors=True,
    ).to(device)

    # One ControlNet per resolved path, loaded in order.
    controlnets = [
        ControlNetModel.from_single_file(path, torch_dtype=torch_dtype).to(device)
        for path in controlnet_paths
    ]

    # Standalone VAE that replaces the one bundled with the checkpoint.
    vae = AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device)

    # Assemble the pipeline; the list form registers multiple ControlNets.
    pipe = StableDiffusionControlNetPipeline(
        vae=vae,
        text_encoder=base_pipe.text_encoder,
        tokenizer=base_pipe.tokenizer,
        unet=base_pipe.unet,
        scheduler=base_pipe.scheduler,
        safety_checker=base_pipe.safety_checker,
        feature_extractor=base_pipe.feature_extractor,
        controlnet=controlnets,
    ).to(device)

    # Swap in the UniPC multistep scheduler, keeping the existing scheduler config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
def get_wd_tags(images: list) -> list:
    """Run the tagger over ``images`` and return the result of ``generate_tags``.

    Each image is passed through ``wd14_preprocess_image``; the results are
    stacked into one NumPy array and handed to ``generate_tags`` together
    with the ``wd_model_name`` environment variable and the loaded model.

    Raises:
        ValueError: If the module-level ``model`` has not been initialized.
    """
    if model is None:
        raise ValueError("Model is not initialized")

    batch = np.array([wd14_preprocess_image(img) for img in images])
    return generate_tags(batch, os.environ["wd_model_name"], model)
def preprocess_image_for_generation(image):
    """Normalize an input to a PIL image whose longer side is 736 px.

    Args:
        image: A base64-encoded string, a NumPy array, or a PIL image.

    Returns:
        A tuple ``(resized_image, output_width, output_height)``.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, str):
        # Base64 text: decode to bytes and let PIL parse them.
        raw_bytes = base64.b64decode(image)
        image = Image.open(io.BytesIO(raw_bytes))
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Unsupported image type")

    # The longer side becomes max_size; the other side is scaled by the
    # aspect ratio and truncated to an int.
    max_size = 736
    input_width, input_height = image.size
    if input_width > input_height:
        output_width = max_size
        output_height = int(input_height / input_width * max_size)
    elif input_height > input_width:
        output_width = int(input_width / input_height * max_size)
        output_height = max_size
    else:
        # Square input: both sides go through the ratio formula.
        output_width = int(input_width / input_height * max_size)
        output_height = int(input_height / input_width * max_size)

    resized = image.resize((output_width, output_height))
    return resized, output_width, output_height
def binarize_image(image: Image.Image) -> np.ndarray:
    """Turn an image into a binary mask via adaptive thresholding.

    Steps, in order: convert to 8-bit grayscale, invert, equalize contrast
    with CLAHE (clip limit 1.0, 8x8 tiles), smooth with a 5x5 Gaussian blur,
    then apply OpenCV's Gaussian adaptive threshold (max value 255, block
    size 9, constant -8).

    Returns:
        The thresholded image as a NumPy array.
    """
    # Invert so that 255 - pixel is what gets equalized and thresholded.
    inverted = 255 - np.array(image.convert('L'))

    equalizer = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8))
    smoothed = cv2.GaussianBlur(equalizer.apply(inverted), (5, 5), 0)

    return cv2.adaptiveThreshold(
        smoothed,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        9,
        -8,
    )
def create_rgba_image(binary_image: np.ndarray, color: list) -> Image.Image:
    """Build an RGBA image of a single color whose alpha is ``binary_image``.

    Args:
        binary_image: 2-D array used verbatim as the alpha channel.
        color: Sequence whose first three items are written to the R, G and B
            channels respectively.

    Returns:
        A PIL image in ``'RGBA'`` mode with the same height and width as
        ``binary_image``.
    """
    canvas = np.zeros(binary_image.shape[:2] + (4,), dtype=np.uint8)
    for channel in range(3):
        canvas[:, :, channel] = color[channel]
    canvas[:, :, 3] = binary_image
    return Image.fromarray(canvas, 'RGBA')
@spaces.GPU
def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
    """Generate the body ("sotai") image from ``input_image`` with ``sotai_gen_pipe``.

    The input is converted to RGB, scaled so its longer side is 512 px, and
    fed as the control image to both ControlNets of the pipeline.

    Args:
        input_image: Source PIL image.
        output_width: Width requested from the pipeline.
        output_height: Height requested from the pipeline.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If ``sotai_gen_pipe`` has not been initialized.
    """
    input_image = ensure_rgb(input_image)
    if sotai_gen_pipe is None:
        raise ValueError("Model is not initialized")

    prompt = "anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)"
    negative_prompt = "(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)"

    try:
        # Scale the control image so its longer side is 512 px. The shorter
        # side is clamped to 1 so extreme aspect ratios cannot yield a 0-px edge.
        src_width, src_height = input_image.size
        if src_width > src_height:
            control_size = (512, max(1, int(512 * src_height / src_width)))
        else:
            control_size = (max(1, int(512 * src_width / src_height)), 512)
        input_image = input_image.resize(control_size)

        output = sotai_gen_pipe(
            prompt,
            image=[input_image, input_image],
            negative_prompt=negative_prompt,
            num_inference_steps=20,
            guidance_scale=8,
            width=output_width,
            height=output_height,
            num_images_per_prompt=1,  # equivalent to batch_size
            # guess_mode is a single bool in diffusers. The previous
            # [True, True] list only worked because a non-empty list is truthy.
            guess_mode=True,
            controlnet_conditioning_scale=[1.4, 1.3],  # weight of each ControlNet
            # The pipeline's real parameter names are control_guidance_start /
            # control_guidance_end. The former guidance_start / guidance_end
            # (and denoising_strength, which this txt2img pipeline does not
            # have) were not recognized and therefore had no effect.
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
        )
        return output.images[0]
    finally:
        # Free memory: collect Python garbage first so the tensors it releases
        # can actually be returned by empty_cache().
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
@spaces.GPU
def generate_refined_image(prompt: str, original_image: Image.Image, output_width: int, output_height: int, weight1: float, weight2: float) -> Image.Image:
    """Generate a refined image from ``original_image`` with ``refine_gen_pipe``.

    Two control images are used: the first value returned by
    ``scribble_xdog`` on the original pixels, and the original image resized
    to the output size.

    Args:
        prompt: Text prompt passed to the pipeline.
        original_image: Source PIL image (converted to RGB if needed).
        output_width: Width requested from the pipeline.
        output_height: Height requested from the pipeline.
        weight1: Conditioning scale of the first ControlNet (scribble input).
        weight2: Conditioning scale of the second ControlNet (resized original).

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: If ``refine_gen_pipe`` has not been initialized.
    """
    original_image = ensure_rgb(original_image)
    if refine_gen_pipe is None:
        raise ValueError("Model is not initialized")

    try:
        # The scribble is extracted from the full-size pixels, before resizing.
        scribble_image, _ = scribble_xdog(np.array(original_image), 2048, 20)
        original_image = original_image.resize((output_width, output_height))

        output = refine_gen_pipe(
            prompt,
            image=[scribble_image, original_image],  # one control image per ControlNet
            negative_prompt="extra limb, monochrome, black and white",
            num_inference_steps=20,
            width=output_width,
            height=output_height,
            controlnet_conditioning_scale=[weight1, weight2],  # weight of each ControlNet
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
            # guess_mode is a single bool in diffusers. The previous
            # [False, False] list is non-empty and therefore truthy, which
            # switched guess mode ON; pass a real False to keep it off.
            guess_mode=False,
        )
        return output.images[0]
    finally:
        # Free memory: collect Python garbage first so the tensors it releases
        # can actually be returned by empty_cache().
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
    """Produce the body ("sotai") overlay and a sketch image for ``input_image``.

    Modes:
        ``"refine"``: tag the image, generate a refined image from those tags,
            derive a sketch from the refined image (returned as a blue RGBA
            overlay), and generate the body image from the refined image.
        ``"original"``: generate the body image from the input, and return the
            raw ``get_sketch`` result for the input as the sketch.
        ``"sketch"``: run ``get_sketch`` on the input first and generate the
            body image from that sketch.

    Args:
        input_image: Source PIL image (converted to RGB if needed).
        mode: One of the modes above.
        weight1: First ControlNet weight, used only in ``"refine"`` mode.
        weight2: Second ControlNet weight, used only in ``"refine"`` mode.

    Returns:
        ``(sotai_image, sketch_image)`` where ``sotai_image`` is a red RGBA
        overlay built from the binarized body image.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    input_image = ensure_rgb(input_image)

    # Output size: longer side is 736 px, the other is scaled and truncated.
    max_size = 736
    input_width, input_height = input_image.size
    if input_width > input_height:
        output_width = max_size
        output_height = int(input_height / input_width * max_size)
    elif input_height > input_width:
        output_width = int(input_width / input_height * max_size)
        output_height = max_size
    else:
        # Square input: both sides go through the ratio formula.
        output_width = int(input_width / input_height * max_size)
        output_height = int(input_height / input_width * max_size)
    output_size = (output_width, output_height)

    if mode == "refine":
        # Build the prompt from the tagger output for this single image.
        tags = get_wd_tags([np.array(ensure_rgb(input_image))])
        prompt = f"{tags[0]}"

        refined = generate_refined_image(
            prompt, input_image, output_width, output_height, weight1, weight2
        )
        refined = refined.convert('RGB')

        # Sketch of the refined image, resized to the output size, binarized,
        # then drawn as blue lines on a transparent base.
        sketch = get_sketch(np.array(refined), "both", 2048, 10).resize(output_size)
        sketch_image = create_rgba_image(binarize_image(sketch), [0, 0, 255])

        sotai_raw = generate_sotai_image(refined, output_width, output_height)
    elif mode == "original":
        # Body image first, then the sketch of the untouched input.
        sotai_raw = generate_sotai_image(input_image, output_width, output_height)
        sketch_image = get_sketch(np.array(input_image), "both", 2048, 16)
    elif mode == "sketch":
        # Sketch first; the body image is generated from the sketch itself.
        sketch_image = get_sketch(np.array(input_image), "both", 2048, 16)
        sotai_raw = generate_sotai_image(sketch_image, output_width, output_height)
    else:
        raise ValueError("Invalid mode")

    # Binarize the body image and draw it as red lines on a transparent base.
    sotai_image = create_rgba_image(binarize_image(sotai_raw), [255, 0, 0])
    return sotai_image, sketch_image
def image_to_base64(img_array):
    """Serialize an image to PNG and return it as a base64 text string.

    ``img_array`` must provide a ``save(file, format=...)`` method; it is
    asked to write PNG data into an in-memory buffer.
    """
    with io.BytesIO() as png_buffer:
        img_array.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
def process_image_as_base64(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
    """Run ``process_image`` and return both results as base64 PNG strings.

    Returns:
        ``(sotai_base64, sketch_base64)``, in the same order as
        ``process_image`` returns its images.
    """
    results = process_image(input_image, mode, weight1, weight2)
    sotai_image, sketch_image = results
    sotai_b64 = image_to_base64(sotai_image)
    sketch_b64 = image_to_base64(sketch_image)
    return sotai_b64, sketch_b64