# image-blender / app.py
# Author: tonyassi — commit 8a7c98d ("Update app.py")
import gradio as gr
import spaces
from diffusers import KandinskyPriorPipeline, KandinskyPipeline
from diffusers.utils import load_image
import torch
from PIL import Image
# Load both Kandinsky 2.1 pipelines once, at import time, in half precision
# and move them to the GPU:
#   * pipe_prior -- the prior; blend() uses its interpolate() to mix the inputs
#   * pipe       -- the main pipeline that turns the prior output into an image
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior",
    torch_dtype=torch.float16,
)
pipe_prior.to("cuda")

pipe = KandinskyPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1",
    torch_dtype=torch.float16,
)
pipe.to("cuda")
def squarify_image(img, color="white"):
    """Pad ``img`` onto a square canvas without resizing or cropping it.

    The canvas side is the longer of the image's width and height. The image
    is pasted centred on it, so the padding is split evenly between left/right
    (portrait input) or top/bottom (landscape input).

    No ``@spaces.GPU()`` here: this is CPU-only PIL work, and it is called
    from ``blend``, which is already GPU-decorated.

    Args:
        img: a PIL image.
        color: fill colour for the padding; any value accepted by
            ``PIL.Image.new`` (default ``"white"``).

    Returns:
        A new ``"RGB"`` PIL image of size ``(side, side)``, where ``side`` is
        ``max(img.width, img.height)``.
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color=color)
    # One of these offsets is always 0, because only the shorter axis gets
    # padding. Floor division puts any odd leftover pixel on the right/bottom.
    offset = ((side - img.width) // 2, (side - img.height) // 2)
    canvas.paste(img, offset)
    return canvas
@spaces.GPU()
def blend(img1, img2, slider, prompt, negative_prompt):
    """Generate one image that mixes ``img1`` and ``img2``.

    Processing steps:
      1. Both inputs are shrunk to fit within 1024x1024. ``thumbnail`` keeps
         the aspect ratio and never enlarges.
      2. Each is padded to a square by ``squarify_image``.
      3. They are interpolated by the prior with weights
         ``[1 - slider, slider]``.
      4. The prior output, together with the optional text prompts, is passed
         to the main pipeline to render a 1024x1024 result.

    Args:
        img1: first PIL image (weight ``1 - slider``).
        img2: second PIL image (weight ``slider``).
        slider: share given to ``img2``. Presumably in [0, 1] (the UI slider's
            range); it is not validated here.
        prompt: text prompt forwarded to ``pipe`` (may be empty).
        negative_prompt: negative prompt forwarded to ``pipe`` (may be empty).

    Returns:
        The first entry of ``pipe(...).images``.

    Raises:
        gr.Error: if either image is missing, e.g. the user did not upload one.
    """
    if img1 is None or img2 is None:
        raise gr.Error("Please provide both images to blend.")

    # thumbnail() resizes in place; copy first so the caller's images are not mutated.
    img1 = img1.copy()
    img2 = img2.copy()
    img1.thumbnail((1024, 1024))
    img2.thumbnail((1024, 1024))
    img1 = squarify_image(img1)
    img2 = squarify_image(img2)

    # Conditions to interpolate (the prior accepts either text or images) and
    # their weights; slider is the share given to img2.
    images_texts = [img1, img2]
    weights = [1 - slider, slider]
    prior_out = pipe_prior.interpolate(images_texts, weights)

    image = pipe(
        prompt=prompt,
        **prior_out,
        height=1024,
        width=1024,
        negative_prompt=negative_prompt,
    ).images[0]
    return image
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("""
# Image Blender
by [Tony Assi](https://www.tonyassi.com/)
""")

    with gr.Row():
        # Left column: the two source images, the blend weight and optional prompts.
        with gr.Column():
            img1 = gr.Image(label='Image 0', type='pil')
            img2 = gr.Image(label='Image 1', type='pil')
            slider = gr.Slider(label='Weight', maximum=1.0, value=0.5)
            with gr.Accordion("Advanced", open=False):
                prompt = gr.Textbox(label='Prompt', value='')
                negative_prompt = gr.Textbox(label='Negative Prompt', value='')
            btn = gr.Button("Blend")

        # Right column: the generated result.
        with gr.Column():
            output = gr.Image(label='Result')

    # Same order as blend(img1, img2, slider, prompt, negative_prompt);
    # shared by the cached example and the button handler.
    blend_inputs = [img1, img2, slider, prompt, negative_prompt]

    gr.Examples(
        examples=[['./cat.png', './starry_night.jpg', 0.5, '', '']],
        inputs=blend_inputs,
        outputs=output,
        fn=blend,
        cache_examples=True,
    )
    btn.click(fn=blend, inputs=blend_inputs, outputs=output)

demo.launch()