|
import gradio as gr |
|
import numpy as np |
|
import random |
|
from diffusers import DiffusionPipeline, LMSDiscreteScheduler |
|
import torch |
|
import time |
|
|
|
|
|
# Device that every pipeline and random generator in this module is placed on.
DEVICE: str = "cpu"

# Mapping of the human-readable label offered to the user -> Hugging Face
# model id that is loaded when that label is selected.
MODEL_OPTIONS: dict = {
    "Medium Quality (Faster)": "stabilityai/stable-diffusion-2-base",
    "Fastest (Draft Quality)": "hf-internal-testing/tiny-stable-diffusion-pipe",
}

# Model id behind the draft-quality option.
DEFAULT_MODEL_ID: str = MODEL_OPTIONS["Fastest (Draft Quality)"]

# Already-loaded pipelines keyed by model id; filled lazily by load_pipeline().
PIPELINES: dict = {}
|
|
|
def load_pipeline(model_id):
    """Return the diffusion pipeline for ``model_id``, loading it on first request.

    A newly loaded pipeline uses float32 weights and has its scheduler
    swapped for ``LMSDiscreteScheduler`` (built from the original scheduler's
    config). It is moved to ``DEVICE`` and stored in ``PIPELINES``, so later
    calls with the same id return the cached object.
    """
    try:
        return PIPELINES[model_id]
    except KeyError:
        pass  # not cached yet: fall through and load it

    loaded = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    # Reuse the shipped scheduler's configuration for the LMS replacement.
    loaded.scheduler = LMSDiscreteScheduler.from_config(loaded.scheduler.config)
    loaded = loaded.to(DEVICE)

    PIPELINES[model_id] = loaded
    return loaded
|
|
|
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height,
                   guidance_scale, num_inference_steps, num_images, model_choice):
    """Generate images for ``prompt`` with the model selected in the UI.

    Args:
        prompt: Text description of the desired image; must not be blank.
        negative_prompt: Passed to the pipeline as ``negative_prompt``.
        seed: Seed used when ``randomize_seed`` is false. Gradio numeric
            components may deliver it as a float, so it is cast to ``int``.
        randomize_seed: When true, a fresh random seed is drawn for this call.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        num_inference_steps: Number of inference steps (cast to ``int``).
        num_images: Images to generate for the prompt (cast to ``int``).
        model_choice: A key of ``MODEL_OPTIONS``; an unknown or missing key
            falls back to ``DEFAULT_MODEL_ID`` instead of raising ``KeyError``.

    Returns:
        A ``(images, status)`` tuple: the pipeline's images and a status
        string containing the elapsed generation time.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Будь ласка, введіть опис для зображення.")

    pipe = load_pipeline(MODEL_OPTIONS.get(model_choice, DEFAULT_MODEL_ID))

    # A new torch.Generator starts from a fixed default seed, so leaving it
    # unseeded would make every "randomized" run produce the same images.
    # Draw an explicit random seed instead and always seed the generator.
    if randomize_seed:
        seed = random.randint(0, int(np.iinfo(np.int32).max))
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    # perf_counter is monotonic and meant for measuring elapsed durations.
    start_time = time.perf_counter()
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=int(num_images),
        generator=generator,
    ).images
    generation_time = time.perf_counter() - start_time

    return images, f"Час генерації: {generation_time:.2f} секунд"
|
|
|
|
|
|
|
|
|
# Output widgets plus the button that triggers a generation run.
# NOTE(review): the input components passed to `inputs` below (prompt,
# negative_prompt, seed, ...) and the gr.Blocks context this wiring needs are
# not defined in this section; presumably they are created just before it.
# Confirm.
run_button = gr.Button("Згенерувати")
gallery = gr.Gallery(label="Згенеровані зображення")
status_text = gr.Textbox(label="Статус")

# generate_image returns (images, status), matching the outputs order.
run_button.click(
    fn=generate_image,
    outputs=[gallery, status_text],
    inputs=[
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        num_images,
        model_choice,
    ],
)
|
|