import os
import random
import uuid

import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusion3Pipeline, DPMSolverMultistepScheduler, AutoencoderKL, StableDiffusion3Img2ImgPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
from huggingface_hub import login

# Authenticate only when a token is configured: `login(token=None)` falls back
# to an interactive prompt, which would block a headless deployment.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)

# NOTE(review): DESCRIPTION is built here but not rendered anywhere in the UI below.
DESCRIPTION = """# Stable Diffusion 3"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = False
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_pipeline():
    """Load the Stable Diffusion 3 Medium text-to-image pipeline.

    Weights are loaded in float16 when CUDA is available and in float32
    otherwise, since half precision is poorly supported on CPU.

    Returns:
        A ``StableDiffusion3Pipeline`` (not yet moved to any device).
    """
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        # device_map="balanced",
        torch_dtype=dtype,
    )
    return pipe


# Load the weights once at import time; `generate` reuses this instance
# instead of re-reading the whole model from disk on every request.
pipe = load_pipeline()
if ENABLE_CPU_OFFLOAD:
    # Opt-in via ENABLE_CPU_OFFLOAD=1: keep sub-models on CPU and move each to
    # the accelerator only while it runs, trading speed for lower VRAM use.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}


def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=64):
    """Compute an output (width, height) for a named aspect ratio.

    Both dimensions are rounded down to a multiple of ``divisibility``, and the
    height is reduced step by step until ``width * height <= total_pixels``.

    Args:
        aspect_ratio: A key of ``aspect_ratios`` such as ``"16:9"``.
        mode: ``'portrait'`` swaps the ratio; any other value is treated as
            landscape.
        total_pixels: Upper bound on ``width * height``.
        divisibility: Both dimensions are made multiples of this value.

    Returns:
        A ``(width, height)`` tuple of ints.

    Raises:
        ValueError: If ``aspect_ratio`` is not a key of ``aspect_ratios``.
    """
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    ratio = width_multiplier / height_multiplier
    if mode == 'portrait':
        # Swap the ratio for portrait mode
        ratio = 1 / ratio

    height = int((total_pixels / ratio) ** 0.5)
    height -= height % divisibility
    width = int(height * ratio)
    width -= width % divisibility

    # Rounding the width can still leave the product over budget; shrink the
    # height one step at a time and recompute the width until it fits.
    while width * height > total_pixels:
        height -= divisibility
        width = int(height * ratio)
        width -= width % divisibility

    return width, height


def save_image(img):
    """Save ``img`` to the current directory under a random UUID name and return that filename."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in ``[0, MAX_SEED]`` if requested, otherwise ``seed`` unchanged."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    aspect: str = "1:1",
    mode: str = "landscape",
    guidance_scale: float = 7.5,
    randomize_seed: bool = False,
    num_inference_steps=30,
    NUM_IMAGES_PER_PROMPT=1,
    use_resolution_binning: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Run text-to-image generation with the module-level SD3 pipeline.

    Args:
        prompt: Text prompt.
        negative_prompt: Text to steer away from; ignored unless
            ``use_negative_prompt`` is true.
        use_negative_prompt: Whether to pass ``negative_prompt`` to the pipeline.
        seed: Seed for the noise generator (replaced if ``randomize_seed``).
        aspect: Aspect-ratio key passed to ``calculate_resolution``.
        mode: ``'landscape'`` or ``'portrait'``.
        guidance_scale: Classifier-free guidance scale.
        randomize_seed: Draw a new random seed instead of using ``seed``.
        num_inference_steps: Number of denoising steps.
        NUM_IMAGES_PER_PROMPT: Number of images to generate.
        use_resolution_binning: Accepted for API compatibility; currently unused.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm bars in the UI.

    Returns:
        The list of PIL images produced by the pipeline (``output_type="pil"``).
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    # A CPU generator keeps results for a given seed independent of the device.
    generator = torch.Generator().manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore

    width, height = calculate_resolution(aspect, mode)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        # Slider values may arrive as floats; the pipeline needs integer counts.
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        num_images_per_prompt=int(NUM_IMAGES_PER_PROMPT),
        output_type="pil",
    ).images

    return output


examples = [
    "Beautiful pixel art of a wizard with hovering text \"Achievement unlocked: Diffusion models can spell now\"",
    "Frog sitting in a 1950s diner wearing a leather jacket and a top hat. on the table a giant burger and a small sign that says \"froggy fridays\"",
    "This dreamlike digital art capture a vibrant kaleidoscopic bird in a rainforest",
    "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo saying \"SD3 rocks!\"",
    "post-apocalyptic city wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic",
    "a dark-armored warrior with ornate golden details, cloaked in a flowing black cape, wielding a radiant, fiery sword, standing amidst an ominous cloudy backdrop with dramatic lighting, exuding a menacing, powerful presence.",
    "A wise old wizard with a long white beard, flowing robes, and a gnarled staff, casting a spell, photorealistic style",
    "Design a film poster for a noir thriller set in 1940s Los Angeles, featuring a shadowy figure under a streetlamp and a foggy, mysterious ambiance.",
]

css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
'''

with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
<h1>Stable Diffusion 3</h1>
"""
            )
            with gr.Group():
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Row():
                    aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
                    mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')
            result = gr.Gallery(label="Result", elem_id="gallery", show_label=False)

            with gr.Accordion("Advanced options", open=False):
                with gr.Row():
                    use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=1,
                    value="deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # At least one denoising step is required; 0 is not a valid schedule.
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=60,
                    step=1,
                    value=30,
                )
                number_image = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=2,
                    step=1,
                    value=1,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.1,
                        maximum=10,
                        step=0.1,
                        value=7.0,
                    )

            gr.Examples(
                examples=examples,
                inputs=prompt,
                outputs=[result],
                fn=generate,
                cache_examples=CACHE_EXAMPLES,
            )

    # Show the negative-prompt box only while the checkbox is ticked.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        # Order must match generate's positional parameters.
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            seed,
            aspect,
            mode,
            guidance_scale,
            randomize_seed,
            steps,
            number_image,
        ],
        outputs=[result],
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue().launch()