# IronJayx
# fixed inputs
# 7d125a7
import asyncio
import base64
import glob
import os
import time
from io import BytesIO

import gradio as gr
import requests
from comfydeploy import ComfyDeploy
from dotenv import load_dotenv
from gradio_imageslider import ImageSlider
from PIL import Image
# Pull variables from a local .env file (if present) into os.environ first.
load_dotenv()

API_KEY = os.getenv("COMFY_DEPLOY_API_KEY")
DEPLOYMENT_ID = os.getenv("COMFY_DEPLOYMENT_ID")

# Fail fast at import time: no request handler can work without both values.
if not (API_KEY and DEPLOYMENT_ID):
    raise ValueError(
        "Please set COMFY_DEPLOY_API_KEY and COMFY_DEPLOYMENT_ID in your environment variables"
    )

# Single module-level ComfyDeploy client shared by every handler.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Encode *image* as PNG and return the bytes as a base64 text string.

    The result has no ``data:`` prefix; callers add it themselves.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def get_profile(profile) -> dict:
    """Copy the ``username``, ``profile`` and ``name`` attributes of *profile* into a dict.

    The keys appear in that order and match the attribute names.
    """
    wanted_fields = ("username", "profile", "name")
    return {field: getattr(profile, field) for field in wanted_fields}
async def process(
    image: Image.Image | None = None,
    profile: gr.OAuthProfile | None = None,
    progress: gr.Progress = gr.Progress(),
) -> tuple[Image.Image, Image.Image] | None:
    """Gradio click handler: send *image* through the ComfyDeploy deployment.

    Args:
        image: The uploaded input image, or ``None`` if nothing was uploaded.
        profile: OAuth profile of the logged-in user, injected by Gradio from the
            type hint; ``None`` when the user is not logged in.
        progress: Progress tracker injected by Gradio.

    Returns:
        ``(original, processed)`` for the comparison slider. ``processed`` is ``None``
        when the run failed. Returns ``None`` instead when no image was uploaded or
        the user is not logged in; an info toast is shown in those cases.
    """
    if image is None:
        gr.Info("Please upload an image ")
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    user_data = get_profile(profile)
    print("--------- RUN ----------")
    print(user_data)

    progress(0, desc="Preparing inputs...")
    image_base64 = get_base64_from_image(image)
    inputs = {
        "image": f"data:image/png;base64,{image_base64}",
        # `params` maps workflow input names to gr.Slider components. str() of a
        # component is not its number, so send each slider's `.value` instead.
        # NOTE(review): a server-side component's `.value` is its configured initial
        # value; live UI adjustments only arrive if passed as event inputs.
        **{name: str(slider.value) for name, slider in params.items()},
    }

    output = await process_image(inputs, progress)
    # gr.Progress expects a completion fraction in [0, 1].
    progress(1.0, desc="Processing completed")
    return image, output
# Run statuses that end polling without a result. NOTE(review): "timeout" and
# "cancelled" are assumed to be terminal ComfyDeploy run statuses besides
# "failed" -- confirm against the ComfyDeploy API docs.
_TERMINAL_FAILURE_STATUSES = frozenset({"failed", "timeout", "cancelled"})


def _first_output_image(outputs) -> Image.Image | None:
    """Download and decode the first image found in a run's outputs; None if there is none."""
    for output in outputs or []:
        if output.data and output.data.images:
            image_url: str = output.data.images[0].url
            # Bounded wait so a stalled download cannot hang the handler forever.
            response = requests.get(image_url, timeout=60)
            response.raise_for_status()
            processed_image = Image.open(BytesIO(response.content))
            # Image.open is lazy; decode now so corrupt data raises here.
            processed_image.load()
            return processed_image
    return None


async def process_image(
    inputs: dict,
    progress: gr.Progress,
    poll_interval: float = 1.0,
    timeout: float | None = None,
) -> Image.Image | None:
    """Start a ComfyDeploy run with *inputs* and poll it until it finishes.

    Args:
        inputs: Workflow inputs forwarded as-is to the deployment.
        progress: Gradio progress tracker, updated after every poll.
        poll_interval: Seconds to wait between status checks.
        timeout: Optional limit, in seconds, on the polling phase. ``None`` means
            wait indefinitely.

    Returns:
        The first output image of a successful run. Returns ``None`` if the run could
        not be created, failed, finished without an image, or timed out. Any
        exception is printed and also yields ``None``.
    """
    try:
        # The ComfyDeploy client and `requests` are synchronous. Run them in worker
        # threads so this coroutine does not block Gradio's event loop, and with it
        # every other user's request, while waiting on the network.
        result = await asyncio.to_thread(
            client.run.create,
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs},
        )
        if not (result and result.object):
            print("Processing failed: the run could not be created")
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            run_result = await asyncio.to_thread(client.run.get, run_id=run_id)
            run = run_result.object
            if run:
                # NOTE(review): assumes `run.progress` is a 0-1 fraction, as
                # gr.Progress expects -- confirm against the ComfyDeploy API.
                progress_value = run.progress or 0
                status = run.live_status or "Cold starting..."
                progress(progress_value, desc=f"Status: {status}")
                if run.status == "success":
                    # Stop polling even if no image was produced. Otherwise a
                    # finished run would be re-fetched forever.
                    return await asyncio.to_thread(_first_output_image, run.outputs)
                if run.status in _TERMINAL_FAILURE_STATUSES:
                    print("Processing failed")
                    return None
            if deadline is not None and time.monotonic() >= deadline:
                print("Processing timed out")
                return None
            # Always pause between polls, including when no run object came back,
            # rather than spinning on the API. asyncio.sleep yields the event loop.
            await asyncio.sleep(poll_interval)
    except Exception as e:
        # Top-level boundary of the UI handler: report the error and return no image.
        print(f"Error: {e}")
        return None
def load_preset_images(pattern: str = "images/inputs/*") -> list[dict]:
    """Open every image matching *pattern* whose format is in an allow-list.

    Args:
        pattern: ``glob`` pattern for candidate files. Defaults to the example
            inputs folder.

    Returns:
        A list of ``{"name": path, "image": PIL image}`` dicts, in ``glob`` order.
        Files Pillow cannot open or identify, and images whose format is not
        allowed, are skipped instead of raising.
    """
    allowed_formats = {"png", "jpg", "jpeg", "gif", "bmp", "webp"}
    presets = []
    for path in glob.glob(pattern):
        try:
            # Open each file exactly once and keep that handle for the result.
            image = Image.open(path)
        except OSError:
            # Unidentifiable data (UnidentifiedImageError subclasses OSError),
            # directories, and unreadable files are skipped.
            continue
        # `.format` may be None, so guard before calling lower().
        if (image.format or "").lower() in allowed_formats:
            presets.append({"name": path, "image": image})
        else:
            # Release the file handle of images we are not keeping.
            image.close()
    return presets
def build_example(input_image_path):
    """Build one ``gr.Examples`` row for *input_image_path*.

    The row contains:
    - the input path;
    - the eight parameter-slider default values;
    - an ``(input, output)`` pair for the comparison slider.

    The output path swaps ``"inputs"`` for ``"outputs"``.
    NOTE(review): ``str.replace`` swaps *every* occurrence, including any inside
    the file name.
    """
    # Same order and values as the slider defaults: denoise, steps, tile_size,
    # downscale, upscale, color_match, controlnet_tile_end, controlnet_tile_strength.
    slider_defaults = [0.4, 10, 1024, 1, 4, 0, 1, 0.7]
    comparison_pair = (
        input_image_path,
        input_image_path.replace("inputs", "outputs"),
    )
    return [input_image_path, *slider_defaults, comparison_pair]
def serialize_params(params: dict) -> dict:
    """Reduce a mapping of components to plain ``{"value": ..., "label": ...}`` dicts.

    Each entry keeps its key and copies the component's ``.value`` and ``.label``.
    """
    serialized = {}
    for key_name, component in params.items():
        serialized[key_name] = {"value": component.value, "label": component.label}
    return serialized
# Gradio UI: input image + parameter sliders, a comparison slider for the
# result, login and run buttons, and example rows built from images/inputs/*.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
<h1 style="color: #333;">🌊 Creative Image Upscaler</h1>
<div style="max-width: 800px; margin: 0 auto;">
<p style="font-size: 16px;">Upload an image and adjust the parameters to enhance your image.</p>
<p style="font-size: 16px;">Click on the <b>"Run"</b> button to process the image and compare the original and processed images using the slider.</p>
<p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the workloads of the demo.</p>
</div>
</div>
""")
    with gr.Row(equal_height=False):
        with gr.Column():
            # TODO: the input image can overflow its column; fix the layout.
            input_image = gr.Image(type="pil", label="Input Image", interactive=True)
            with gr.Accordion("Advanced parameters", open=False):
                # Keys are the workflow input names. The dict order must match the
                # slider values returned by build_example(), because gr.Examples
                # below maps example columns to these components in this order.
                params = {
                    "denoise": gr.Slider(0, 1, value=0.4, label="Denoise"),
                    "steps": gr.Slider(1, 25, value=10, label="Steps"),
                    "tile_size": gr.Slider(256, 2048, value=1024, label="Tile Size"),
                    "downscale": gr.Slider(1, 4, value=1, label="Downscale"),
                    "upscale": gr.Slider(1, 4, value=4, label="Upscale"),
                    "color_match": gr.Slider(0, 1, value=0, label="Color Match"),
                    "controlnet_tile_end": gr.Slider(
                        0, 1, value=1, label="ControlNet Tile End"
                    ),
                    "controlnet_tile_strength": gr.Slider(
                        0, 1, value=0.7, label="ControlNet Tile Strength"
                    ),
                }
        with gr.Column():
            image_slider = ImageSlider(
                label="Compare Original and Processed", interactive=False
            )

    login_button = gr.LoginButton(scale=8)
    process_btn = gr.Button("Run", variant="primary", size="lg")

    # Disable the button while a run is in flight, run it, then re-enable it.
    # The toggle handlers take no parameters: with `inputs=[]` Gradio calls them
    # with zero arguments, so a one-argument lambda would raise TypeError.
    # NOTE(review): only `input_image` is wired in; the slider values in `params`
    # are not passed to `process`, so user adjustments to them are not sent.
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            input_image,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )

    # Sorted so the example gallery order does not depend on filesystem order.
    examples = [build_example(img) for img in sorted(glob.glob("images/inputs/*"))]
    # Skip the gallery entirely when there are no example images to show.
    if examples:
        gr.Examples(
            examples=examples, inputs=[input_image, *params.values(), image_slider]
        )
if __name__ == "__main__":
    # queue() enables Gradio's request queue and returns the same Blocks object,
    # so launching `demo` afterwards is equivalent to chaining the calls.
    demo.queue()
    demo.launch(debug=True, share=True)