Spaces:
Sleeping
Sleeping
File size: 4,595 Bytes
cf12300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import threading
from collections import deque
from dataclasses import dataclass
from typing import Optional
import gradio as gr
from PIL import Image
from constants import DESCRIPTION, LOGO
from gradio_examples import EXAMPLES
from model import get_pipeline
from utils import replace_background
# Capacity used as `maxlen` for both deques built in get_initial_state():
# the pending-prompt queue and the finished-generation queue.
MAX_QUEUE_SIZE = 4
# Created once at import time; inference() calls this module-level object.
pipeline = get_pipeline()
@dataclass
class GenerationState:
    """Pair of queues that connects input events to displayed results.

    Attributes:
        prompts: Pending requests, each an ``(image, prompt, seed, strength)``
            tuple waiting to be run through the pipeline.
        generations: Finished output images waiting to be shown.
    """

    prompts: deque
    generations: deque
def get_initial_state() -> GenerationState:
    """Return a new state with two empty, bounded queues.

    Both deques are capped at ``MAX_QUEUE_SIZE``; appending to a full deque
    discards its oldest entry, so stale requests/results are dropped rather
    than accumulated.
    """
    return GenerationState(
        prompts=deque(maxlen=MAX_QUEUE_SIZE),
        generations=deque(maxlen=MAX_QUEUE_SIZE),
    )
def load_initial_state(request: gr.Request) -> GenerationState:
    """Log the connecting client and return a fresh, empty state.

    Args:
        request: The Gradio request for the page load; only
            ``request.client.host`` is read, for logging.

    Returns:
        A new ``GenerationState`` from ``get_initial_state()``.
    """
    print("Loading initial state for", request.client.host)
    print("Total number of active threads", threading.active_count())
    return get_initial_state()
async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
) -> GenerationState:
    """Enqueue a generation request when both an image and a prompt exist.

    Args:
        image: The input image, or ``None`` when nothing has been drawn.
        prompt: Text prompt; an empty string means "nothing to do".
        seed: Seed forwarded to the pipeline.
        strength: Strength forwarded to the pipeline.
        state: State whose ``prompts`` queue receives the request.

    Returns:
        The same ``state`` object, with the request appended to
        ``state.prompts`` if the inputs were usable.
    """
    if prompt and image is not None:
        state.prompts.append((image, prompt, seed, strength))
    return state
def inference(state: GenerationState) -> GenerationState:
    """Process the oldest queued request, if any, and queue its result.

    Args:
        state: State holding the ``prompts`` queue to consume from and the
            ``generations`` queue to append to.

    Returns:
        The same ``state`` object. It is returned unchanged when no request
        is pending; otherwise one request has been removed from
        ``state.prompts`` and one image added to ``state.generations``.
    """
    if not state.prompts:
        return state

    image, prompt, seed, strength = state.prompts.popleft()
    # Remember the caller's size so the result can be scaled back to it.
    original_image_size = image.size
    # The pipeline is fed a 512x512 image after background replacement.
    image = replace_background(image.resize((512, 512)))
    result = pipeline(
        prompt=prompt,
        image=image,
        strength=strength,
        seed=seed,
        guidance_scale=1,
        num_inference_steps=4,
    )
    output_image = result.images[0].resize(original_image_size)
    state.generations.append(output_image)
    return state
def update_output_image(state: GenerationState) -> tuple:
    """Turn the oldest finished image, if any, into an output update.

    Args:
        state: State whose ``generations`` queue is consumed from.

    Returns:
        ``(image_update, state)``. ``image_update`` is an empty
        ``gr.update()`` when no image is ready; otherwise it is
        ``gr.update(value=...)`` carrying the image just removed from
        ``state.generations``.
    """
    if state.generations:
        image_update = gr.update(value=state.generations.popleft())
    else:
        image_update = gr.update()
    return image_update, state
# Build the UI and register its event handlers; the resulting `demo` object
# is what the __main__ guard queues and launches.
# NOTE(review): this copy of the file has lost its indentation, so how the
# `with` layout blocks below nest cannot be read from it -- confirm the layout
# against a correctly indented copy.
# NOTE(review): the `f` prefix on the title string has no placeholders and
# looks unnecessary.
with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo:
# State component seeded with an empty GenerationState.
generation_state = gr.State(get_initial_state())
# Header: logo markup followed by the description text.
gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
gr.Markdown(DESCRIPTION)
with gr.Row(variant="default"):
# Drawing canvas used as the source image; passed to put_to_queue as `image`.
input_image = gr.Image(
tool="color-sketch",
source="canvas",
label="Initial Image",
type="pil",
height=512,
width=512,
brush_radius=40.0,
)
# Read-only image that update_output_image writes results into.
output_image = gr.Image(
label="Generated Image",
type="pil",
interactive=False,
elem_id="output_image",
)
with gr.Row():
with gr.Column():
# Prompt text, pre-filled with the first example.
prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
with gr.Column():
# Forwarded to the pipeline as `strength`.
strength = gr.Slider(
label="Strength",
minimum=0.1,
maximum=1.0,
step=0.05,
value=0.8,
info="""
Strength of the initial image that will be applied during inference.
""",
)
with gr.Column():
# Forwarded to the pipeline as `seed`; starts at a random value.
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=2**31 - 1,
step=1,
randomize=True,
info="""
Seed for the random number generator.
""",
)
# On load: log the client and overwrite the state with a fresh one.
demo.load(
load_initial_state,
outputs=[generation_state],
)
# Registered with every=0.1: consume one queued request per run (a no-op
# when the prompt queue is empty).
demo.load(
inference,
inputs=[generation_state],
outputs=[generation_state],
every=0.1,
)
# Registered with every=0.1: move one finished image, if any, to the output.
demo.load(
update_output_image,
inputs=[generation_state],
outputs=[output_image, generation_state],
every=0.1,
)
# A change to the sketch, prompt, strength or seed enqueues a new request.
for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
event(
put_to_queue,
[input_image, prompt_box, seed, strength, generation_state],
[generation_state],
show_progress=False,
queue=True,
)
# Example prompts that fill the prompt box.
gr.Markdown("## Example Prompts")
gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")
if __name__ == "__main__":
    # Enable the request queue (the handlers registered with `every=` and
    # `queue=True` rely on it), keep the API closed, then start the server.
    demo.queue(concurrency_count=20, api_open=False).launch(max_threads=1024)
|