File size: 4,363 Bytes
5bb6a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import gradio as gr
# from PIL import Image
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler

# Parent of the directory that contains this file; weights and outputs are
# resolved relative to it.
root_dir = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)

# Model components and the assembled pipeline. They start out unset and are
# filled in by load_models() the first time it runs.
text_encoder = tokenizer = vae = scheduler = unet = pipe = None

def load_models():
    """Load the Kolors components once and assemble the shared pipeline.

    Reads the checkpoints under ``{root_dir}/weights/Kolors`` and populates
    the module-level ``text_encoder``, ``tokenizer``, ``vae``, ``scheduler``,
    ``unet`` and ``pipe`` globals. Once ``pipe`` is set, further calls return
    immediately.

    The globals are only assigned after every component has loaded and the
    pipeline has been built. If any step raises, they all stay ``None`` and
    the next call retries the full load instead of leaving a half-initialised
    state behind.
    """
    global text_encoder, tokenizer, vae, scheduler, unet, pipe

    # `pipe` is the last thing published, so it is the only reliable marker
    # that loading finished successfully.
    if pipe is not None:
        return

    ckpt_dir = f'{root_dir}/weights/Kolors'

    # Load the text encoder on CPU (this speeds stuff up 2x)
    encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).to('cpu').half()
    tok = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')

    # Load the VAE and UNet on GPU
    autoencoder = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to('cuda')
    sched = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    denoiser = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to('cuda')

    # Prepare the pipeline
    pipeline = StableDiffusionXLPipeline(
        vae=autoencoder,
        text_encoder=encoder,
        tokenizer=tok,
        unet=denoiser,
        scheduler=sched,
        force_zeros_for_empty_prompt=False)
    # NOTE(review): the pipeline is moved to CUDA and then CPU offload is
    # enabled; confirm this ordering against the installed diffusers version.
    pipeline = pipeline.to("cuda")
    pipeline.enable_model_cpu_offload()  # Enable offloading to balance CPU/GPU usage

    # Publish everything together, with `pipe` last (see guard above).
    text_encoder, tokenizer, vae, scheduler, unet = encoder, tok, autoencoder, sched, denoiser
    pipe = pipeline

def infer(prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt):
    """Generate images for ``prompt`` and save them to disk.

    Args:
        prompt: Text prompt passed to the pipeline.
        use_random_seed: When truthy, ``seed`` is ignored and a fresh seed is
            drawn with ``torch.randint``.
        seed: Seed for the generator when ``use_random_seed`` is falsy.
        height: Output image height, passed to the pipeline.
        width: Output image width, passed to the pipeline.
        num_inference_steps: Number of inference steps, passed to the pipeline.
        guidance_scale: Guidance scale, passed to the pipeline unchanged.
        num_images_per_prompt: Number of images to request from the pipeline.

    Returns:
        List of file paths of the saved images, in generation order. Files
        are written to ``{root_dir}/scripts/outputs`` as ``sample_test.jpg``,
        ``sample_test_1.jpg``, ``sample_test_2.jpg``, ... using the first
        names that do not exist yet.
    """
    load_models()

    if use_random_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()

    # UI widgets may hand whole numbers over as floats; manual_seed() only
    # accepts an int, so the integer-valued arguments are coerced explicitly.
    generator = torch.Generator(pipe.device).manual_seed(int(seed))
    images = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=int(num_images_per_prompt),
        generator=generator,
    ).images

    output_dir = f'{root_dir}/scripts/outputs'
    os.makedirs(output_dir, exist_ok=True)

    file_path = os.path.join(output_dir, 'sample_test.jpg')
    base_name, ext = os.path.splitext(file_path)
    counter = 1
    saved_images = []
    for image in images:
        # Every name tried so far is known to exist (found on disk or just
        # written), so the probe continues from the last one instead of
        # restarting at 'sample_test.jpg' for each image.
        while os.path.exists(file_path):
            file_path = f"{base_name}_{counter}{ext}"
            counter += 1
        image.save(file_path)
        saved_images.append(file_path)

    return saved_images

def gradio_interface():
    """Build the Gradio Blocks UI for Kolors image generation.

    The left column holds the prompt and generation controls, the right
    column holds the output gallery. Clicking the button calls ``infer`` with
    the control values and shows the returned images in the gallery.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    def _seed_slider_update(random_enabled):
        # The manual seed slider is shown only while random seeding is off.
        return gr.update(visible=not random_enabled)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kolors: Diffusion Model Gradio Interface")
                prompt_box = gr.Textbox(label="Prompt")
                random_seed_toggle = gr.Checkbox(label="Use Random Seed", value=True)
                seed_slider = gr.Slider(
                    minimum=0, maximum=2**32 - 1, step=1,
                    label="Seed", randomize=True, visible=False,
                )
                random_seed_toggle.change(_seed_slider_update, random_seed_toggle, seed_slider)
                height_slider = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=1024)
                width_slider = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=1024)
                steps_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=50)
                guidance_slider = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0)
                count_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Images per Prompt", value=1)
                generate_button = gr.Button("Generate Image")

            with gr.Column():
                gallery = gr.Gallery(label="Output Images", elem_id="output_gallery")

        # Order must match the positional parameters of infer().
        infer_inputs = [
            prompt_box,
            random_seed_toggle,
            seed_slider,
            height_slider,
            width_slider,
            steps_slider,
            guidance_slider,
            count_slider,
        ]
        generate_button.click(fn=infer, inputs=infer_inputs, outputs=gallery)

    return demo

# Script entry point: build the UI, then start the Gradio server.
if __name__ == '__main__':
    app = gradio_interface()
    app.launch()