Spaces:
Sleeping
Sleeping
import threading | |
from collections import deque | |
from dataclasses import dataclass | |
from typing import Optional | |
import gradio as gr | |
from PIL import Image | |
from constants import DESCRIPTION, LOGO | |
from gradio_examples import EXAMPLES | |
from model import get_pipeline | |
from utils import replace_background | |
# Upper bound for each per-state deque (used as `maxlen` in get_initial_state);
# a deque with `maxlen` drops its oldest entry when a new one is appended to a
# full queue.
MAX_QUEUE_SIZE = 4 | |
# Pipeline object built once at import time and called by inference() for
# every request.
pipeline = get_pipeline() | |
@dataclass
class GenerationState:
    """Pair of queues carried through the Gradio callbacks.

    The class is a dataclass so that it can be built with keyword arguments
    (``GenerationState(prompts=..., generations=...)``), which is how
    ``get_initial_state`` constructs it.
    """

    # Pending requests, stored as (image, prompt, seed, strength) tuples.
    prompts: deque
    # Generated images waiting to be shown in the output component.
    generations: deque
def get_initial_state() -> GenerationState:
    """Return a new state with two empty queues, each capped at MAX_QUEUE_SIZE."""
    pending = deque(maxlen=MAX_QUEUE_SIZE)
    finished = deque(maxlen=MAX_QUEUE_SIZE)
    return GenerationState(prompts=pending, generations=finished)
def load_initial_state(request: gr.Request) -> GenerationState:
    """Log the connecting client and thread count, then hand back a fresh state.

    Args:
        request: The Gradio request; only ``request.client.host`` is read.

    Returns:
        A new, empty ``GenerationState`` from ``get_initial_state``.
    """
    client_host = request.client.host
    print("Loading initial state for", client_host)
    thread_count = threading.active_count()
    print("Total number of active threads", thread_count)
    fresh_state = get_initial_state()
    return fresh_state
async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
):
    """Enqueue a generation request when both a prompt and an image are present.

    Args:
        image: The input image, or ``None`` when there is none.
        prompt: The text prompt; an empty prompt means nothing is queued.
        seed: Seed value stored with the request.
        strength: Strength value stored with the request.
        state: The state whose ``prompts`` queue receives the request.

    Returns:
        The same ``state`` object, whether or not anything was queued.
    """
    pending = state.prompts
    # Nothing to generate from without both inputs.
    if not prompt or image is None:
        return state
    request_tuple = (image, prompt, seed, strength)
    pending.append(request_tuple)
    return state
def inference(state: GenerationState) -> GenerationState:
    """Process the oldest queued request, if any, and store the generated image.

    The request's image is resized to 512x512 and passed through
    ``replace_background`` before being handed to ``pipeline``; the first
    output image is resized back to the input's original size and appended
    to ``state.generations``.

    Args:
        state: The state holding the ``prompts`` and ``generations`` queues.

    Returns:
        The same ``state`` object (unchanged when no request was queued).
    """
    try:
        # EAFP: pop directly rather than checking the length first, so an
        # empty queue is handled in one step with no gap between the check
        # and the pop.
        image, prompt, seed, strength = state.prompts.popleft()
    except IndexError:
        return state

    original_image_size = image.size
    prepared_image = replace_background(image.resize((512, 512)))

    result = pipeline(
        prompt=prompt,
        image=prepared_image,
        strength=strength,
        seed=seed,
        guidance_scale=1,
        num_inference_steps=4,
    )

    # Hand the result back at the size the caller supplied.
    output_image = result.images[0].resize(original_image_size)
    state.generations.append(output_image)
    return state
def update_output_image(state: GenerationState):
    """Pop the oldest finished image, if any, into an update for the output.

    Returns:
        A ``(image_update, state)`` tuple. ``image_update`` is an empty
        ``gr.update()`` when no generated image is waiting, otherwise a
        ``gr.update`` carrying the popped image as its value.
    """
    finished = state.generations
    if not finished:
        # No new image: emit an update that sets nothing.
        return gr.update(), state
    next_image = finished.popleft()
    return gr.update(value=next_image), state
# Build the Gradio UI (bound to `demo`) and wire its events.
# NOTE(review): this copy of the file has lost its indentation and carries
# trailing `| |` residue, so the nesting of the `with` layout blocks below
# cannot be read from the text. Restore it from the upstream source before
# running; the comments here describe only what each statement does.
with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo: | |
# State holder passed into and returned from every callback below. The value
# set here is replaced on page load by load_initial_state (first demo.load).
generation_state = gr.State(get_initial_state()) | |
# Page header: logo followed by the description text.
gr.HTML(f'<div style="width: 70px;">{LOGO}</div>') | |
gr.Markdown(DESCRIPTION) | |
with gr.Row(variant="default"): | |
# Sketch canvas the user draws on; handed to callbacks as a PIL image.
input_image = gr.Image( | |
tool="color-sketch", | |
source="canvas", | |
label="Initial Image", | |
type="pil", | |
height=512, | |
width=512, | |
brush_radius=40.0, | |
) | |
# Read-only image component that update_output_image writes to.
output_image = gr.Image( | |
label="Generated Image", | |
type="pil", | |
interactive=False, | |
elem_id="output_image", | |
) | |
with gr.Row(): | |
with gr.Column(): | |
# Prompt text, pre-filled with the first example.
prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0]) | |
# Collapsed-by-default section holding the strength and seed sliders.
with gr.Accordion(label="Advanced Options", open=False): | |
with gr.Row(): | |
with gr.Column(): | |
strength = gr.Slider( | |
label="Strength", | |
minimum=0.1, | |
maximum=1.0, | |
step=0.05, | |
value=0.8, | |
info=""" | |
Strength of the initial image that will be applied during inference. | |
""", | |
) | |
with gr.Column(): | |
# No explicit value: the slider is created with randomize=True.
seed = gr.Slider( | |
label="Seed", | |
minimum=0, | |
maximum=2**31 - 1, | |
step=1, | |
randomize=True, | |
info=""" | |
Seed for the random number generator. | |
""", | |
) | |
# On page load: replace the state with a fresh one (and log the client).
demo.load( | |
load_initial_state, | |
outputs=[generation_state], | |
) | |
# Registered with every=0.1: repeatedly run inference, which processes at most
# one queued request per call.
demo.load( | |
inference, | |
inputs=[generation_state], | |
outputs=[generation_state], | |
every=0.1, | |
) | |
# Registered with every=0.1: repeatedly move at most one finished image from
# the state into the output component.
demo.load( | |
update_output_image, | |
inputs=[generation_state], | |
outputs=[output_image, generation_state], | |
every=0.1, | |
) | |
# A change to the sketch, prompt, strength or seed calls put_to_queue with the
# current values of all four inputs.
for event in [input_image.change, prompt_box.change, strength.change, seed.change]: | |
event( | |
put_to_queue, | |
[input_image, prompt_box, seed, strength, generation_state], | |
[generation_state], | |
show_progress=False, | |
queue=True, | |
) | |
# Example prompts; each example supplies a value for the prompt box only.
gr.Markdown("## Example Prompts") | |
gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples") | |
if __name__ == "__main__":
    # Script entry point: enable the request queue, then start the server.
    queued_demo = demo.queue(concurrency_count=20, api_open=False)
    queued_demo.launch(max_threads=1024)