"""Realtime latent-consistency-model sketch-to-image demo.

A Gradio app: the user sketches on a canvas and types a prompt; pending
requests are queued per session, consumed by a polling `inference` step,
and finished images are drained into the output component by a second
polling step.
"""

import threading
from collections import deque
from dataclasses import dataclass
from typing import Optional

import gradio as gr
from PIL import Image

from constants import DESCRIPTION, LOGO
from gradio_examples import EXAMPLES
from model import get_pipeline
from utils import replace_background

# Bound on pending prompts and finished images kept per session; older
# entries are silently dropped by the deques' maxlen.
MAX_QUEUE_SIZE = 4

pipeline = get_pipeline()


@dataclass
class GenerationState:
    """Per-session work queues.

    Attributes:
        prompts: pending (image, prompt, seed, strength) request tuples.
        generations: finished PIL images waiting to be displayed.
    """

    prompts: deque
    generations: deque


def get_initial_state() -> GenerationState:
    """Return a fresh session state with bounded queues."""
    return GenerationState(
        prompts=deque(maxlen=MAX_QUEUE_SIZE),
        generations=deque(maxlen=MAX_QUEUE_SIZE),
    )


def load_initial_state(request: gr.Request) -> GenerationState:
    """Create a new per-session state when a client connects."""
    print("Loading initial state for", request.client.host)
    print("Total number of active threads", threading.active_count())
    return get_initial_state()


def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
) -> GenerationState:
    """Enqueue a generation request if both a sketch and a prompt exist.

    Note: was declared ``async`` with no ``await``; Gradio accepts plain
    callables, so the needless coroutine wrapper is removed.
    """
    if prompt and image is not None:
        state.prompts.append((image, prompt, seed, strength))
    return state


def inference(state: GenerationState) -> GenerationState:
    """Pop one queued request, run the pipeline, and queue the result.

    Returns the (mutated) session state on every path — the original
    ``-> Image.Image`` annotation was wrong.
    """
    prompts_queue = state.prompts
    generations_queue = state.generations

    if not prompts_queue:
        return state

    image, prompt, seed, strength = prompts_queue.popleft()

    original_image_size = image.size
    # The pipeline works at 512x512; remember the sketch size so the
    # generated image can be resized back for display.
    image = replace_background(image.resize((512, 512)))

    result = pipeline(
        prompt=prompt,
        image=image,
        strength=strength,
        seed=seed,
        guidance_scale=1,
        num_inference_steps=4,
    )
    generations_queue.append(result.images[0].resize(original_image_size))
    return state


def update_output_image(state: GenerationState):
    """Drain one finished image (if any) into a Gradio update."""
    image_update = gr.update()
    if state.generations:
        image_update = gr.update(value=state.generations.popleft())
    return image_update, state


with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    generation_state = gr.State(get_initial_state())

    # NOTE(review): the original HTML wrapper around {LOGO} was garbled in
    # this copy; reconstructed as a triple-quoted f-string preserving the
    # visible content (newline, LOGO, newline) — confirm against upstream.
    gr.HTML(f"""
{LOGO}
""")
    gr.Markdown(DESCRIPTION)

    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )
        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])

    with gr.Accordion(label="Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.8,
                    info="""
                    Strength of the initial image that will be applied during inference.
                    """,
                )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**31 - 1,
                    step=1,
                    randomize=True,
                    info="""
                    Seed for the random number generator.
                    """,
                )

    # Fresh state per client connection.
    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
    # Poll: consume one queued prompt per tick.
    demo.load(
        inference,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    # Poll: surface one finished image per tick.
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, generation_state],
        every=0.1,
    )

    # Any change to the sketch or the generation parameters enqueues a request.
    for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
        event(
            put_to_queue,
            [input_image, prompt_box, seed, strength, generation_state],
            [generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")


if __name__ == "__main__":
    demo.queue(concurrency_count=20, api_open=False).launch(max_threads=1024)