# Hugging Face Spaces status: Paused
import asyncio
import base64
import glob
import os
import time
from io import BytesIO

import gradio as gr
import requests
from comfydeploy import ComfyDeploy
from dotenv import load_dotenv
from gradio_imageslider import ImageSlider
from PIL import Image
# Pull variables from a local .env file (if any) into os.environ.
load_dotenv()

# ComfyDeploy credentials; both must be present for the app to work.
API_KEY = os.environ.get("COMFY_DEPLOY_API_KEY")
DEPLOYMENT_ID = os.environ.get("COMFY_DEPLOYMENT_ID")

# Abort at import time instead of failing on the first user request.
if not (API_KEY and DEPLOYMENT_ID):
    raise ValueError(
        "Please set COMFY_DEPLOY_API_KEY and COMFY_DEPLOYMENT_ID in your environment variables"
    )

# Shared API client used by every request handler below.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes as base64 text.

    Args:
        image: The PIL image to encode.

    Returns:
        The PNG bytes of *image*, base64-encoded and decoded to a UTF-8 str.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def get_profile(profile) -> dict:
    """Extract the identifying fields of a login profile into a plain dict.

    Args:
        profile: Any object exposing ``username``, ``profile`` and ``name``
            attributes (in this app, the ``gr.OAuthProfile`` of the user).

    Returns:
        ``{"username": ..., "profile": ..., "name": ...}`` in that key order.
    """
    wanted_fields = ("username", "profile", "name")
    return {field: getattr(profile, field) for field in wanted_fields}
def _resolve_param_values(param_values: tuple) -> dict:
    """Map positional slider values onto the keys of ``params``.

    Values are matched to ``params`` in insertion order. Any key without a
    supplied value falls back to the slider's ``value`` attribute (the value it
    was created with), so the workflow always receives numbers rather than the
    string form of a Gradio component.
    """
    resolved = {name: slider.value for name, slider in params.items()}
    resolved.update(zip(params, param_values))
    return resolved


async def process(
    image: Image.Image | None = None,
    profile: gr.OAuthProfile | None = None,
    progress: gr.Progress = gr.Progress(),
    *param_values,
) -> tuple[Image.Image, Image.Image | None] | None:
    """Run the upscaling workflow on *image*; return (original, processed).

    Bound to the "Run" button. *profile* and *progress* are supplied by Gradio
    (login state and progress tracker), not through the event's ``inputs``.
    Any extra positional arguments are the current slider values, in the same
    order as ``params``; missing ones default to the sliders' initial values.

    Returns:
        ``None`` (after showing a notice) when no image is uploaded or the user
        is not logged in; otherwise ``(image, processed)`` where ``processed``
        is ``None`` if the remote run failed.
    """
    if image is None:
        gr.Info("Please upload an image ")
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    user_data = get_profile(profile)
    print("--------- RUN ----------")
    print(user_data)

    progress(0, desc="Preparing inputs...")
    image_base64 = get_base64_from_image(image)
    inputs = {
        "image": f"data:image/png;base64,{image_base64}",
        **{
            name: str(value)
            for name, value in _resolve_param_values(param_values).items()
        },
    }
    output = await process_image(inputs, progress)
    if output is None:
        gr.Warning("Image processing failed. Please try again.")
    # gr.Progress expects a completion fraction in [0, 1], not a percentage.
    progress(1.0, desc="Processing completed")
    return image, output


# Run statuses after which polling stops without producing an output.
# NOTE(review): only "failed" appears in the original code; "cancelled" and
# "timeout" are assumed to be terminal ComfyDeploy statuses too — confirm
# against the SDK's run-status values.
_TERMINAL_FAILURE_STATUSES = {"failed", "cancelled", "timeout"}

# Upper bound (seconds) for downloading the generated image.
_DOWNLOAD_TIMEOUT_S = 60


async def _download_first_image(outputs) -> Image.Image | None:
    """Download and open the first image found in a run's *outputs*, if any."""
    for output in outputs or []:
        if output.data and output.data.images:
            image_url: str = output.data.images[0].url
            response = await asyncio.to_thread(
                requests.get, image_url, timeout=_DOWNLOAD_TIMEOUT_S
            )
            # Surface HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    return None


async def process_image(
    inputs: dict,
    progress: gr.Progress,
    poll_interval: float = 1.0,
    timeout: float | None = None,
) -> Image.Image | None:
    """Submit *inputs* to the ComfyDeploy deployment and wait for the result.

    The run is polled every *poll_interval* seconds and its progress/status is
    reported through *progress*.

    Args:
        inputs: Workflow inputs sent as the run's ``inputs`` payload.
        progress: Gradio progress tracker to update while polling.
        poll_interval: Seconds to wait between status checks.
        timeout: Optional overall limit in seconds; ``None`` polls until the
            run reaches a final status (the original behavior).

    Returns:
        The first output image, or ``None`` if the run could not be created,
        failed, finished without images, timed out, or any error occurred
        (errors are printed, not raised).
    """
    try:
        # The SDK and requests calls are synchronous; run them in a worker
        # thread so the event loop (and other users' sessions) is not blocked.
        result = await asyncio.to_thread(
            client.run.create,
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs},
        )
        if not (result and result.object):
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        while deadline is None or loop.time() < deadline:
            run_result = await asyncio.to_thread(client.run.get, run_id=run_id)
            run = run_result.object
            if run:
                progress_value = run.progress or 0
                status = run.live_status or "Cold starting..."
                progress(progress_value, desc=f"Status: {status}")
                if run.status == "success":
                    # Return even when there are no images; otherwise a
                    # finished run would be polled forever.
                    return await _download_first_image(run.outputs)
                if run.status in _TERMINAL_FAILURE_STATUSES:
                    print("Processing failed")
                    return None
            # Wait before every retry, including when no run object came back
            # (the original re-polled immediately in that case).
            await asyncio.sleep(poll_interval)

        print("Processing timed out")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None


# Image formats (as reported by PIL's ``Image.format``) accepted as presets.
_PRESET_IMAGE_FORMATS = {"png", "jpg", "jpeg", "gif", "bmp", "webp"}


def load_preset_images(pattern: str = "images/inputs/*") -> list[dict]:
    """Open every supported image matching *pattern*.

    Each file is opened once. Files that PIL cannot read (including
    directories) are skipped. Files in an unsupported format are closed and
    skipped.

    Args:
        pattern: ``glob`` pattern for candidate files.

    Returns:
        A list of ``{"name": path, "image": PIL image}`` dicts.
    """
    presets = []
    for path in glob.glob(pattern):
        try:
            image = Image.open(path)
        except OSError:
            # PIL's UnidentifiedImageError is an OSError subclass.
            continue
        if (image.format or "").lower() in _PRESET_IMAGE_FORMATS:
            presets.append({"name": path, "image": image})
        else:
            image.close()
    return presets


# Example slider values, in the same order as ``params``: denoise, steps,
# tile_size, downscale, upscale, color_match, controlnet_tile_end,
# controlnet_tile_strength.
_EXAMPLE_PARAM_VALUES = (0.4, 10, 1024, 1, 4, 0, 1, 0.7)


def build_example(input_image_path):
    """Build one ``gr.Examples`` row for *input_image_path*.

    The row holds the input path, the example slider values, and an
    ``(input, output)`` path pair for the comparison slider. The output path
    is the sibling ``outputs`` directory when the file lives in a directory
    named ``inputs``. Otherwise it falls back to replacing "inputs" with
    "outputs" in the whole path.
    """
    directory, filename = os.path.split(input_image_path)
    parent, leaf = os.path.split(directory)
    if leaf == "inputs":
        # Only swap the directory, never an "inputs" inside the file name.
        output_image_path = os.path.join(parent, "outputs", filename)
    else:
        output_image_path = input_image_path.replace("inputs", "outputs")
    return [
        input_image_path,
        *_EXAMPLE_PARAM_VALUES,
        (input_image_path, output_image_path),
    ]


def serialize_params(params: dict) -> dict:
    """Return ``{key: {"value": ..., "label": ...}}`` for each component."""
    return {
        key: {"value": param.value, "label": param.label}
        for key, param in params.items()
    }


with gr.Blocks() as demo:
    # NOTE(review): the title originally began with an emoji that is
    # mis-encoded in this copy of the file; restore it if desired.
    gr.HTML("""
    <div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
        <h1 style="color: #333;">Creative Image Upscaler</h1>
        <div style="max-width: 800px; margin: 0 auto;">
            <p style="font-size: 16px;">Upload an image and adjust the parameters to enhance your image.</p>
            <p style="font-size: 16px;">Click on the <b>"Run"</b> button to process the image and compare the original and processed images using the slider.</p>
            <p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the workloads of the demo.</p>
        </div>
    </div>
    """)
    with gr.Row(equal_height=False):
        with gr.Column():
            # TODO: the input image can overflow its container; fix the layout.
            input_image = gr.Image(type="pil", label="Input Image", interactive=True)
            with gr.Accordion("Advanced parameters", open=False):
                # Keys are sent to the workflow as input names; the order here
                # is the order the values are passed to ``process``.
                params = {
                    "denoise": gr.Slider(0, 1, value=0.4, label="Denoise"),
                    "steps": gr.Slider(1, 25, value=10, label="Steps"),
                    "tile_size": gr.Slider(256, 2048, value=1024, label="Tile Size"),
                    "downscale": gr.Slider(1, 4, value=1, label="Downscale"),
                    "upscale": gr.Slider(1, 4, value=4, label="Upscale"),
                    "color_match": gr.Slider(0, 1, value=0, label="Color Match"),
                    "controlnet_tile_end": gr.Slider(
                        0, 1, value=1, label="ControlNet Tile End"
                    ),
                    "controlnet_tile_strength": gr.Slider(
                        0, 1, value=0.7, label="ControlNet Tile Strength"
                    ),
                }
        with gr.Column():
            image_slider = ImageSlider(
                label="Compare Original and Processed", interactive=False
            )
            login_button = gr.LoginButton(scale=8)
            process_btn = gr.Button("Run", variant="primary", size="lg")

    # Disable the button, run the workflow, then re-enable it. `.then` runs
    # after the previous step finishes, so the button is re-enabled even if
    # processing fails. The callbacks take no arguments because inputs=[].
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        # Pass the current slider values, not just the image.
        inputs=[input_image, *params.values()],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )

    examples = [build_example(img) for img in sorted(glob.glob("images/inputs/*"))]
    if examples:
        gr.Examples(
            examples=examples, inputs=[input_image, *params.values(), image_slider]
        )
if __name__ == "__main__":
    # Enable request queuing on the app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True, share=True)