Spaces:
Paused
Paused
import gradio as gr | |
from gradio_imageslider import ImageSlider | |
import os | |
from comfydeploy import ComfyDeploy | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
from dotenv import load_dotenv | |
import base64 | |
from typing import Optional, Tuple, Union | |
import glob | |
# Load environment variables (python-dotenv) before they are read below.
load_dotenv()

# ComfyDeploy client authenticated with the API key taken from the environment.
client: ComfyDeploy = ComfyDeploy(bearer_auth=os.getenv("COMFY_DEPLOY_API_KEY"))

# Deployment that runs are submitted to; None when the variable is not set.
deployment_id: Optional[str] = os.getenv("COMFY_DEPLOYMENT_ID")

# Module-level UI state shared between the Gradio callbacks in this file.
global_input_image = None  # image chosen as input (assigned by set_input_image)
global_image_slider = None  # [original, processed] pair of the last run, or None
def clear_output() -> None:
    """Return ``None`` so that a bound output component is emptied."""
    cleared = None
    return cleared
def _encode_image_as_data_url(image: "Union[str, Image.Image]") -> str:
    """Return *image* (a file path or a PIL image) as a base64 data URL."""
    if isinstance(image, str):
        # A path: send the file's bytes exactly as stored on disk.
        with open(image, "rb") as img_file:
            raw = img_file.read()
    else:
        # A PIL image: serialise it to PNG in memory.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        raw = buffered.getvalue()
    encoded = base64.b64encode(raw).decode("utf-8")
    # NOTE(review): the prefix always says PNG, even for a path that points at
    # another format. Kept unchanged because the receiving side may match on
    # this exact prefix - confirm before altering it.
    return f"data:image/png;base64,{encoded}"


def _download_first_output_image(
    outputs, request_timeout: float = 60.0
) -> "Optional[Image.Image]":
    """Download and open the first image found in a run's *outputs*.

    Returns ``None`` when no output carries an image.
    """
    for output in outputs or ():
        if output.data and output.data.images:
            image_url = output.data.images[0].url
            # Bound the download and surface HTTP errors instead of handing an
            # error page to the image decoder.
            response = requests.get(image_url, timeout=request_timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    return None


def process_image(
    image: "Optional[Union[str, Image.Image]]",
    denoise: float,
    steps: int,
    tile_size: int,
    downscale: float,
    upscale: float,
    color_match: float,
    controlnet_tile_end: float,
    controlnet_tile_strength: float,
    *,
    poll_interval: float = 1.0,
    timeout: float = 600.0,
) -> "Tuple[Optional[Union[str, Image.Image]], Optional[Image.Image]]":
    """Submit *image* to the ComfyDeploy deployment and fetch the result.

    Args:
        image: File path or PIL image to process; ``None`` short-circuits.
        denoise, steps, tile_size, downscale, upscale, color_match,
        controlnet_tile_end, controlnet_tile_strength: Workflow parameters,
            forwarded to the deployment as strings.
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum seconds to wait for the run to finish.

    Returns:
        ``(image, processed_image)`` on success, where the first element is
        the *image* argument unchanged. ``(None, None)`` when there is no
        input, the run fails or times out, it produces no image, or an API or
        download error occurs (the error is printed).
    """
    import time  # local import: only the polling loop needs it

    if image is None:
        return None, None

    # Prepare inputs
    inputs: dict = {
        "image": _encode_image_as_data_url(image),
        "denoise": str(denoise),
        "steps": str(steps),
        "tile_size": str(tile_size),
        "downscale": str(downscale),
        "upscale": str(upscale),
        "color_match": str(color_match),
        "controlnet_tile_end": str(controlnet_tile_end),
        "controlnet_tile_strength": str(controlnet_tile_strength),
    }

    # Call ComfyDeploy API. This is a UI callback boundary, so any failure is
    # reported and turned into an empty result instead of propagating.
    try:
        result = client.run.create(
            request={"deployment_id": deployment_id, "inputs": inputs}
        )
        if not (result and result.object):
            return None, None

        run_id: str = result.object.run_id
        deadline = time.monotonic() + timeout

        # Poll until the run reaches a terminal state or the deadline passes.
        while True:
            run_result = client.run.get(run_id=run_id)
            status = run_result.object.status

            if status == "success":
                processed_image = _download_first_output_image(
                    run_result.object.outputs
                )
                if processed_image is None:
                    return None, None
                return image, processed_image

            if status == "failed":
                return None, None

            if time.monotonic() >= deadline:
                print(f"Error: run {run_id} did not finish within {timeout} seconds")
                return None, None

            # Pause between polls so the API is not hit in a tight loop.
            time.sleep(max(poll_interval, 0.0))
    except Exception as e:
        print(f"Error: {e}")

    return None, None
def run(
    denoise,
    steps,
    tile_size,
    downscale,
    upscale,
    color_match,
    controlnet_tile_end,
    controlnet_tile_strength,
):
    """Process the globally selected input image with the given settings.

    Reads ``global_input_image`` and stores the outcome in
    ``global_image_slider``.

    Returns:
        ``[original, processed]`` when processing yields both images,
        otherwise ``None`` (also when no input image is set).
    """
    global global_input_image, global_image_slider

    # Nothing selected yet: leave the previous slider state untouched.
    if not global_input_image:
        return None

    # Drop any earlier result before starting a new run.
    global_image_slider = None

    settings = (
        denoise,
        steps,
        tile_size,
        downscale,
        upscale,
        color_match,
        controlnet_tile_end,
        controlnet_tile_strength,
    )
    before, after = process_image(global_input_image, *settings)

    if before and after:
        global_image_slider = [before, after]
    return global_image_slider
# Function to load preset images
def load_preset_images():
    """Open the preset images stored under ``images/inputs``.

    Each file is opened once. Entries that cannot be opened as an image
    (directories, unreadable or non-image files) are skipped instead of
    aborting the load, and files in a format outside the accepted list are
    closed and dropped.

    Returns:
        A list of ``{"name": <path>, "image": <opened PIL image>}`` dicts in
        sorted path order.
    """
    accepted_formats = {"png", "jpg", "jpeg", "gif", "bmp", "webp"}
    presets = []
    # Sort so the order does not depend on the filesystem.
    for path in sorted(glob.glob("images/inputs/*")):
        try:
            image = Image.open(path)
        except OSError:
            # Pillow's "cannot identify image" error is an OSError as well.
            continue
        if (image.format or "").lower() in accepted_formats:
            presets.append({"name": path, "image": image})
        else:
            image.close()
    return presets
def set_input_image(images, evt: gr.SelectData):
    """Store the selected gallery entry as the global input image.

    Takes the first element of ``images[evt.index]``, assigns it to
    ``global_input_image`` and returns it.
    """
    global global_input_image
    selected_entry = images[evt.index]
    global_input_image = selected_entry[0]
    return global_input_image
# Define Gradio interface
with gr.Blocks() as demo:
    # NOTE(review): the "π" in the title looks like a mis-decoded emoji -
    # confirm the intended character; it is kept as found.
    gr.Markdown("# π Creative Image Upscaler")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Input Image",
                value=lambda: global_input_image,
                interactive=True,
            )

            def _sync_input_image(image):
                """Remember the image currently shown in the input component."""
                global global_input_image
                global_input_image = image

            # run() reads global_input_image instead of this component, so an
            # uploaded image or a clicked example must be copied into it;
            # otherwise only gallery selections would ever be processed.
            input_image.change(_sync_input_image, inputs=input_image)

            # Add preset images
            gr.Markdown("### Preset Images")
            preset_images = load_preset_images()
            gallery = gr.Gallery(
                [img["image"] for img in preset_images],
                label="Preset Images",
                columns=5,
                height=130,
                allow_preview=False,
            )
            # Selecting a preset stores it globally and shows it as the input.
            gallery.select(set_input_image, gallery, input_image)

            with gr.Accordion("Advanced Parameters", open=False):
                denoise: gr.Slider = gr.Slider(0, 1, value=0.4, label="Denoise")
                steps: gr.Slider = gr.Slider(1, 40, value=10, step=1, label="Steps")
                tile_size: gr.Slider = gr.Slider(
                    64, 2048, value=1024, step=8, label="Tile Size"
                )
                downscale: gr.Slider = gr.Slider(
                    1, 4, value=1, step=1, label="Downscale"
                )
                upscale: gr.Slider = gr.Slider(1, 4, value=4, step=0.1, label="Upscale")
                color_match: gr.Slider = gr.Slider(0, 1, value=0, label="Color Match")
                controlnet_tile_end: gr.Slider = gr.Slider(
                    0, 1, value=1, label="ControlNet Tile End"
                )
                controlnet_tile_strength: gr.Slider = gr.Slider(
                    0, 1, value=0.7, label="ControlNet Tile Strength"
                )

        with gr.Column():
            image_slider = ImageSlider(
                label="Compare Original and Processed",
                type="pil",
                value=lambda: global_image_slider,
                interactive=True,
            )

    # NOTE(review): the original indentation was lost; the button is placed
    # below the two columns - confirm the intended position.
    process_btn: gr.Button = gr.Button("Run")
    process_btn.click(
        fn=run,
        inputs=[
            denoise,
            steps,
            tile_size,
            downscale,
            upscale,
            color_match,
            controlnet_tile_end,
            controlnet_tile_strength,
        ],
        outputs=[image_slider],
    )

    def build_example(input_image_path):
        """Build one examples row: input path, default settings, slider pair."""
        # Swap only the first "inputs" (the directory component) so a file
        # name that happens to contain "inputs" is left intact.
        output_image_path = input_image_path.replace("inputs", "outputs", 1)
        return [
            input_image_path,
            0.4,  # denoise
            10,  # steps
            1024,  # tile_size
            1,  # downscale
            4,  # upscale
            0,  # color_match
            1,  # controlnet_tile_end
            0.7,  # controlnet_tile_strength
            (input_image_path, output_image_path),
        ]

    # Build examples (sorted so their order does not depend on the filesystem)
    input_images = sorted(glob.glob("images/inputs/*"))
    examples = [build_example(img) for img in input_images]

    # Each example fills the input image, every slider and the comparison view.
    gr.Examples(
        examples=examples,
        inputs=[
            input_image,
            denoise,
            steps,
            tile_size,
            downscale,
            upscale,
            color_match,
            controlnet_tile_end,
            controlnet_tile_strength,
            image_slider,
        ],
    )
if __name__ == "__main__":
    # Start the app only when this file is executed as a script.
    launch_options = {"debug": True, "share": True}
    demo.launch(**launch_options)