Spaces:
Runtime error
Runtime error
live params!
Browse files- app-img2img.py +15 -16
- img2img/index.html +11 -13
app-img2img.py
CHANGED
@@ -49,7 +49,7 @@ else:
|
|
49 |
pipe.set_progress_bar_config(disable=True)
|
50 |
pipe.to(torch_device="cuda", torch_dtype=torch.float32)
|
51 |
pipe.unet.to(memory_format=torch.channels_last)
|
52 |
-
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
53 |
user_queue_map = {}
|
54 |
|
55 |
# for torch.compile
|
@@ -58,7 +58,7 @@ pipe(prompt="warmup", image=[Image.new("RGB", (512, 512))])
|
|
58 |
def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
|
59 |
generator = torch.manual_seed(seed)
|
60 |
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
|
61 |
-
num_inference_steps =
|
62 |
results = pipe(
|
63 |
prompt=prompt,
|
64 |
# generator=generator,
|
@@ -66,7 +66,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
|
|
66 |
strength=strength,
|
67 |
num_inference_steps=num_inference_steps,
|
68 |
guidance_scale=guidance_scale,
|
69 |
-
lcm_origin_steps=
|
70 |
output_type="pil",
|
71 |
)
|
72 |
nsfw_content_detected = (
|
@@ -111,11 +111,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
111 |
await websocket.send_json(
|
112 |
{"status": "success", "message": "Connected", "userId": uid}
|
113 |
)
|
114 |
-
params = await websocket.receive_json()
|
115 |
-
params = InputParams(**params)
|
116 |
user_queue_map[uid] = {
|
117 |
-
"queue": asyncio.Queue()
|
118 |
-
"params": params,
|
119 |
}
|
120 |
await websocket.send_json(
|
121 |
{"status": "start", "message": "Start Streaming", "userId": uid}
|
@@ -148,19 +145,16 @@ async def stream(user_id: uuid.UUID):
|
|
148 |
try:
|
149 |
user_queue = user_queue_map[uid]
|
150 |
queue = user_queue["queue"]
|
151 |
-
|
152 |
-
seed = params.seed
|
153 |
-
prompt = params.prompt
|
154 |
-
strength = params.strength
|
155 |
-
guidance_scale = params.guidance_scale
|
156 |
-
|
157 |
async def generate():
|
158 |
while True:
|
159 |
-
|
|
|
|
|
160 |
if input_image is None:
|
161 |
continue
|
162 |
|
163 |
-
image = predict(input_image, prompt, guidance_scale, strength, seed)
|
164 |
if image is None:
|
165 |
continue
|
166 |
frame_data = io.BytesIO()
|
@@ -190,6 +184,8 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
190 |
try:
|
191 |
while True:
|
192 |
data = await websocket.receive_bytes()
|
|
|
|
|
193 |
pil_image = Image.open(io.BytesIO(data))
|
194 |
|
195 |
while not queue.empty():
|
@@ -197,7 +193,10 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
197 |
queue.get_nowait()
|
198 |
except asyncio.QueueEmpty:
|
199 |
continue
|
200 |
-
await queue.put(
|
|
|
|
|
|
|
201 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
202 |
await websocket.send_json(
|
203 |
{
|
|
|
49 |
pipe.set_progress_bar_config(disable=True)
|
50 |
pipe.to(torch_device="cuda", torch_dtype=torch.float32)
|
51 |
pipe.unet.to(memory_format=torch.channels_last)
|
52 |
+
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
53 |
user_queue_map = {}
|
54 |
|
55 |
# for torch.compile
|
|
|
58 |
def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232):
|
59 |
generator = torch.manual_seed(seed)
|
60 |
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
|
61 |
+
num_inference_steps = 3
|
62 |
results = pipe(
|
63 |
prompt=prompt,
|
64 |
# generator=generator,
|
|
|
66 |
strength=strength,
|
67 |
num_inference_steps=num_inference_steps,
|
68 |
guidance_scale=guidance_scale,
|
69 |
+
lcm_origin_steps=20,
|
70 |
output_type="pil",
|
71 |
)
|
72 |
nsfw_content_detected = (
|
|
|
111 |
await websocket.send_json(
|
112 |
{"status": "success", "message": "Connected", "userId": uid}
|
113 |
)
|
|
|
|
|
114 |
user_queue_map[uid] = {
|
115 |
+
"queue": asyncio.Queue()
|
|
|
116 |
}
|
117 |
await websocket.send_json(
|
118 |
{"status": "start", "message": "Start Streaming", "userId": uid}
|
|
|
145 |
try:
|
146 |
user_queue = user_queue_map[uid]
|
147 |
queue = user_queue["queue"]
|
148 |
+
|
|
|
|
|
|
|
|
|
|
|
149 |
async def generate():
|
150 |
while True:
|
151 |
+
data = await queue.get()
|
152 |
+
input_image = data["image"]
|
153 |
+
params = data["params"]
|
154 |
if input_image is None:
|
155 |
continue
|
156 |
|
157 |
+
image = predict(input_image, params.prompt, params.guidance_scale, params.strength, params.seed)
|
158 |
if image is None:
|
159 |
continue
|
160 |
frame_data = io.BytesIO()
|
|
|
184 |
try:
|
185 |
while True:
|
186 |
data = await websocket.receive_bytes()
|
187 |
+
params = await websocket.receive_json()
|
188 |
+
params = InputParams(**params)
|
189 |
pil_image = Image.open(io.BytesIO(data))
|
190 |
|
191 |
while not queue.empty():
|
|
|
193 |
queue.get_nowait()
|
194 |
except asyncio.QueueEmpty:
|
195 |
continue
|
196 |
+
await queue.put({
|
197 |
+
"image": pil_image,
|
198 |
+
"params": params
|
199 |
+
})
|
200 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
201 |
await websocket.send_json(
|
202 |
{
|
img2img/index.html
CHANGED
@@ -21,10 +21,10 @@
|
|
21 |
const queueSizeEl = document.querySelector("#queue_size");
|
22 |
const errorEl = document.querySelector("#error");
|
23 |
|
24 |
-
function LCMLive(webcamVideo, liveImage) {
|
25 |
let websocket;
|
26 |
|
27 |
-
async function start(
|
28 |
return new Promise((resolve, reject) => {
|
29 |
const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
|
30 |
}:${window.location.host}/ws`;
|
@@ -46,7 +46,6 @@
|
|
46 |
const data = JSON.parse(event.data);
|
47 |
switch (data.status) {
|
48 |
case "success":
|
49 |
-
socket.send(JSON.stringify(params));
|
50 |
break;
|
51 |
case "start":
|
52 |
const userId = data.userId;
|
@@ -71,6 +70,12 @@
|
|
71 |
ctx.drawImage(webcamVideo, 0, 0, canvas.width, canvas.height);
|
72 |
const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
|
73 |
websocket.send(blob);
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
}
|
75 |
|
76 |
function initVideoStream(userId) {
|
@@ -124,15 +129,11 @@
|
|
124 |
}
|
125 |
|
126 |
|
127 |
-
const lcmLive = LCMLive(videoEl, imageEl);
|
128 |
startBtn.addEventListener("click", async () => {
|
129 |
try {
|
130 |
-
const seed = seedEl.value;
|
131 |
-
const prompt = promptEl.value;
|
132 |
-
const guidance_scale = guidanceEl.value;
|
133 |
-
const strength = strengthEl.value;
|
134 |
startBtn.disabled = true;
|
135 |
-
const res = await lcmLive.start(
|
136 |
startBtn.disabled = false;
|
137 |
if (res.status === "timeout")
|
138 |
toggleMessage("success")
|
@@ -176,9 +177,6 @@
|
|
176 |
target="_blank" class="text-blue-500 hover:underline">Diffusers</a> with a MJPEG
|
177 |
stream server.
|
178 |
</p>
|
179 |
-
<p class="text-sm">
|
180 |
-
To change settings or prompt, stop the current stream and start a new one.
|
181 |
-
</p>
|
182 |
<p class="text-sm">
|
183 |
There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
|
184 |
real-time performance. Maximum queue size is 4. <a
|
@@ -218,7 +216,7 @@
|
|
218 |
<input type="number" id="seed" name="seed" value="299792458"
|
219 |
class="font-light border border-gray-700 text-right rounded-md p-2">
|
220 |
<button
|
221 |
-
onclick="document.querySelector('#seed').value =
|
222 |
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
|
223 |
Rand
|
224 |
</button>
|
|
|
21 |
const queueSizeEl = document.querySelector("#queue_size");
|
22 |
const errorEl = document.querySelector("#error");
|
23 |
|
24 |
+
function LCMLive(webcamVideo, liveImage, seedEl, promptEl, guidanceEl, strengthEl) {
|
25 |
let websocket;
|
26 |
|
27 |
+
async function start() {
|
28 |
return new Promise((resolve, reject) => {
|
29 |
const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
|
30 |
}:${window.location.host}/ws`;
|
|
|
46 |
const data = JSON.parse(event.data);
|
47 |
switch (data.status) {
|
48 |
case "success":
|
|
|
49 |
break;
|
50 |
case "start":
|
51 |
const userId = data.userId;
|
|
|
70 |
ctx.drawImage(webcamVideo, 0, 0, canvas.width, canvas.height);
|
71 |
const blob = await canvas.convertToBlob({ type: "image/jpeg", quality: 1 });
|
72 |
websocket.send(blob);
|
73 |
+
websocket.send(JSON.stringify({
|
74 |
+
"seed": seedEl.value,
|
75 |
+
"prompt": promptEl.value,
|
76 |
+
"guidance_scale": guidanceEl.value,
|
77 |
+
"strength": strengthEl.value
|
78 |
+
}));
|
79 |
}
|
80 |
|
81 |
function initVideoStream(userId) {
|
|
|
129 |
}
|
130 |
|
131 |
|
132 |
+
const lcmLive = LCMLive(videoEl, imageEl, seedEl, promptEl, guidanceEl, strengthEl);
|
133 |
startBtn.addEventListener("click", async () => {
|
134 |
try {
|
|
|
|
|
|
|
|
|
135 |
startBtn.disabled = true;
|
136 |
+
const res = await lcmLive.start();
|
137 |
startBtn.disabled = false;
|
138 |
if (res.status === "timeout")
|
139 |
toggleMessage("success")
|
|
|
177 |
target="_blank" class="text-blue-500 hover:underline">Diffusers</a> with a MJPEG
|
178 |
stream server.
|
179 |
</p>
|
|
|
|
|
|
|
180 |
<p class="text-sm">
|
181 |
There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU, affecting
|
182 |
real-time performance. Maximum queue size is 4. <a
|
|
|
216 |
<input type="number" id="seed" name="seed" value="299792458"
|
217 |
class="font-light border border-gray-700 text-right rounded-md p-2">
|
218 |
<button
|
219 |
+
onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
|
220 |
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
|
221 |
Rand
|
222 |
</button>
|