import gradio as gr

from io import BytesIO
import requests
import PIL
from PIL import Image
import numpy as np
import os
import uuid
import torch
from torch import autocast
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the inpainting pipeline. `torch_dtype` is the keyword that from_pretrained
# reads; half precision is only used on GPU, since many CPU ops lack fp16 support.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    revision="fp16",
).to(device)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((512, 512)),
])


def read_content(file_path):
    """Return the text content of the file at file_path."""
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def predict(inputs, prompt=""):
    """Inpaint the masked region of the uploaded image according to the prompt."""
    # With tool='sketch', Gradio passes a dict holding the PIL "image" and the drawn "mask".
    init_image = inputs["image"].convert("RGB").resize((512, 512))
    mask = inputs["mask"].convert("RGB").resize((512, 512))
    with autocast("cuda"):
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask,
            strength=0.8,
            num_inference_steps=20,
        )
    return output.images[0]


examples = [[dict(image="init_image.png", mask="mask_image.png"), "A panda sitting on a bench"]]

css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            # Sketch-enabled upload: the user draws the mask directly on the image.
            image = gr.Image(source='upload', tool='sketch', elem_id="image_upload",
                             type="pil", label="Upload").style(height=400)
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                prompt = gr.Textbox(label='Your prompt (what you want to replace)')
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

    # gr.Examples requires the example rows; pass the list defined above.
    ex = gr.Examples(examples=examples, fn=predict, inputs=[image, prompt],
                     outputs=image, cache_examples=False)
    ex.dataset.headers = [""]

    # The result is written back into the same image component.
    btn.click(fn=predict, inputs=[image, prompt], outputs=image)

    gr.HTML(
        """