# Stick2Body / app.py
# Committed by: tori29umai
# Last change: "Update app.py" (commit fe724f3, verified)
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation
from utils.prompt_utils import remove_duplicates
# --- Setup -----------------------------------------------------------------
# Model artefacts are stored in sub-directories of the current working
# directory; they are created if missing and then populated by the
# project's download helpers (utils.dl_utils).
path = os.getcwd()
cn_dir, lora_dir = (f"{path}/{sub}" for sub in ("controlnet", "lora"))

os.makedirs(cn_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)

# ControlNet: weights first, then its config; LoRA weights last.
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_lora_model(lora_dir)
# Model loading function
def load_model(lora_dir, cn_dir, dtype=torch.float16):
    """Build the SDXL ControlNet img2img pipeline with the hand-fix LoRA applied.

    Args:
        lora_dir: Directory containing ``Fixhands_anime_bdsqlsz_V1.safetensors``.
        cn_dir: Directory holding the ControlNet weights/config, loaded with
            ``use_safetensors=True``.
        dtype: Torch dtype used for the VAE, the ControlNet and the base
            pipeline alike. Defaults to ``torch.float16`` (the previously
            hard-coded value).

    Returns:
        A ``StableDiffusionXLControlNetImg2ImgPipeline`` built on
        ``cagliostrolab/animagine-xl-3.1`` with diffusers' model CPU offload
        enabled and the LoRA weights loaded.
    """
    # Swap in the "madebyollin/sdxl-vae-fp16-fix" VAE — presumably chosen for
    # half-precision stability, per its repo name.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=dtype,
    )
    # Keep sub-models on CPU until each is needed, lowering peak GPU memory.
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
    return pipe
# Image prediction and processing function
@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    """Turn a stick-figure image into a posing-doll image.

    Args:
        input_image_path: Path to the stick-figure image, as produced by the
            ``gr.Image(type='filepath')`` input; ``None`` when nothing was uploaded.
        prompt: Extra user prompt, appended after the fixed quality/style tags.
        negative_prompt: Negative prompt passed straight to the pipeline.
        controlnet_scale: ControlNet conditioning scale (coerced to ``float``).

    Returns:
        The generated ``PIL.Image``, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was supplied (shown to the user by Gradio).
    """
    # Fail fast with a readable message instead of an opaque Image.open error,
    # and before paying for the (expensive) model load below.
    if input_image_path is None:
        raise gr.Error("Please upload a stick-figure image first.")

    # NOTE(review): the full pipeline is rebuilt on every request; consider
    # caching it once ZeroGPU/offload compatibility has been confirmed.
    pipe = load_model(lora_dir, cn_dir)

    # Copy the pixels out so the file handle is closed right away rather than
    # lingering until garbage collection.
    with Image.open(input_image_path) as src:
        input_image = src.copy()

    # The img2img init image is built by base_generation with a white RGBA fill
    # (presumably a blank white canvas — helper defined in utils.image_utils);
    # the stick figure itself is only used as the ControlNet control image.
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)

    # Fixed seed 0 for reproducible output. A dedicated CPU generator yields the
    # same seed-0 stream as torch.manual_seed(0) without reseeding torch's
    # process-wide RNG.
    generator = torch.Generator(device="cpu").manual_seed(0)
    start = time.perf_counter()

    prompt = "masterpiece, best quality, simple background, white background, bald, nude, " + (prompt or "")
    prompt = remove_duplicates(prompt)
    print(prompt)

    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - start}")

    # Restore the caller's original dimensions.
    return output_image.resize(input_image.size, Image.LANCZOS)
class Img2Img:
    """Gradio UI wrapper: builds the Blocks layout and wires it to ``predict``.

    After construction, ``self.demo`` holds the ``gr.Blocks`` app and the
    component attributes (``input_image_path``, ``prompt``, ``negative_prompt``,
    ``controlnet_scale``, ``output_image``) reference the Gradio components
    created in ``layout``.
    """

    def __init__(self):
        # Placeholders are initialised *before* layout(): previously they were
        # set afterwards, which overwrote the gr.Image component that layout()
        # stores in self.input_image_path.
        self.tagger_model = None  # not used by the code shown here
        self.canny_image = None  # not used by the code shown here
        self.input_image_path = None  # replaced by the gr.Image component in layout()
        self.demo = self.layout()

    def layout(self):
        """Create the Blocks UI, store its components on ``self``, and return it."""
        css = """
#intro{
max-width: 32rem;
text-align: center;
margin: 0 auto;
}
"""
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                gr.Image(value="title.png", label="Title Image")
                # Japanese text: "This app converts stick figures into posing-doll
                # images. Use the link below as a reference for the stick-figure
                # shape; hand-drawn stick figures are recognised if the shape
                # roughly matches."
                gr.Markdown("### Stickman to Posing Doll Image Converter\n\nこのアプリは棒人間をポーズ人形画像に変換するアプリです。入力する棒人間の形状は以下のリンクを参考にしてください。\nある程度形状が一致していれば手書きの棒人間でも認識されます\n\n[VRoid Hub Character Example](https://hub.vroid.com/characters/4765753841994800453/models/6738034259079048708)")
            with gr.Row():
                with gr.Column(scale=1):
                    # type='filepath' makes predict() receive a path string (or None).
                    self.input_image_path = gr.Image(label="Input Image", type='filepath')
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    # Default negative prompt; "lowres" (the standard quality tag)
                    # replaces the former typo "lowers".
                    self.negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowres, blurry, bokeh")
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Controlnet Scale")
                    generate_button = gr.Button("Generate")
                with gr.Column(scale=1):
                    self.output_image = gr.Image(type="pil", label="Output Image")

            # Input order must match predict(input_image_path, prompt,
            # negative_prompt, controlnet_scale).
            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image,
            )
        return demo
# Build and launch the UI only when run as a script (as Spaces does with
# `python app.py`), so importing this module does not start a server.
# NOTE(review): share=True creates a public link when run locally; on
# Hugging Face Spaces Gradio ignores it.
if __name__ == "__main__":
    img2img = Img2Img()
    img2img.demo.launch(share=True)