from PIL import Image
import numpy as np
from rembg import remove
import cv2
import os
from torchvision.transforms import GaussianBlur
import gradio as gr
import replicate
import requests
from io import BytesIO

# Intermediate files shared between create_mask (writer) and generate_image
# (which uploads them to the inpainting model).
INPUT_PATH = "input.png"
BG_REMOVED_PATH = "bg_removed.png"
MASK_PATH = "blured_mask.png"
MASK_SIZE = (512, 512)

INPAINTING_MODEL = "cjwbw/stable-diffusion-v2-inpainting"
INPAINTING_VERSION = "f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec"


def create_mask(input, *, strength=1.0, threshold=0):
    """Remove the background of *input* and build a blurred inpainting mask.

    Side effects: writes INPUT_PATH (the original upload), BG_REMOVED_PATH
    (the background-removed image resized to MASK_SIZE) and MASK_PATH (the
    blurred mask).

    Args:
        input: PIL image uploaded by the user. The name shadows the builtin
            but is kept so existing keyword callers keep working.
        strength: in [0, 1]. Product pixels get the value
            int(255 * (1 - strength)) and background pixels get 255. The
            default 1.0 gives 0 for the product.
        threshold: a pixel counts as background when its alpha value is
            <= threshold. If the image has no alpha channel, its grayscale
            value is used instead.

    Returns:
        The blurred single-channel ("L") PIL mask. 255 marks background and
        low values mark the product.

    Raises:
        ValueError: if strength is outside [0, 1] (the uint8 mask would
            otherwise wrap around).
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength!r}")

    input.save(INPUT_PATH)

    bg_removed = remove(input).resize(MASK_SIZE)
    bg_removed.save(BG_REMOVED_PATH)

    # Classify by transparency rather than brightness. A grayscale test would
    # treat genuinely black product pixels (gray == 0) as background.
    # Presumably rembg marks removed background as fully transparent; fall
    # back to the grayscale test when the image carries no alpha channel.
    if "A" in bg_removed.getbands():
        coverage = np.array(bg_removed.getchannel("A"))
    else:
        coverage = np.array(bg_removed.convert("L"))

    keep_value = int(255 * (1 - strength))
    mask = np.where(coverage > threshold, keep_value, 255).astype(np.uint8)

    # Blur the hard mask edge (kernel size 11, sigma 20).
    mask = GaussianBlur(11, 20)(Image.fromarray(mask))
    mask.save(MASK_PATH)
    return mask


def generate_image(image, product_name, target_name):
    """Place the uploaded product into a new scene via Replicate inpainting.

    Args:
        image: PIL image from the upload widget. None if nothing was uploaded.
        product_name: free-text description of the product.
        target_name: free-text description of the desired surroundings.

    Returns:
        The first generated image, as a PIL image.

    Raises:
        gr.Error: if no image was uploaded or the model returned no output.
        requests.HTTPError: if downloading the generated image fails.
    """
    if image is None:
        raise gr.Error("Please upload a photo of your product first.")

    # create_mask writes BG_REMOVED_PATH and MASK_PATH. Those files are what
    # gets uploaded below.
    create_mask(image)

    prompt = f"a photo of a {product_name} with {target_name} product photography"
    model = replicate.models.get(INPAINTING_MODEL)
    version = model.versions.get(INPAINTING_VERSION)
    with open(BG_REMOVED_PATH, "rb") as image_file, open(MASK_PATH, "rb") as mask_file:
        output = version.predict(prompt=prompt, image=image_file, mask=mask_file)

    if not output:
        raise gr.Error("The inpainting model returned no image.")

    response = requests.get(output[0], timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload your product's photo", type="pil")
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(label="Where do you want to put your product?")
            image_button = gr.Button("Generate")
        with gr.Column():
            image_output = gr.Image()

    # Event listeners must be attached inside the Blocks context.
    image_button.click(
        generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
        api_name="test",
    )

demo.launch()