"""Gradio demo: place a product photo into an AI-generated scene.

The uploaded photo has its background removed with ``rembg``; the removed
background becomes a blurred inpainting mask, and the Replicate-hosted model
``cjwbw/stable-diffusion-v2-inpainting`` repaints that region from a prompt
built from the user's product and setting descriptions.
"""
import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import replicate
import requests
from PIL import Image
from rembg import remove
from torchvision.transforms import GaussianBlur

# Intermediate files written to the current working directory; the last two
# are uploaded to Replicate by generate_image().
# NOTE(review): these fixed names are shared by every request, so concurrent
# users of one server can overwrite each other's files — consider tempfile.
INPUT_PATH = 'input.png'
BG_REMOVED_PATH = 'bg_removed.png'
MASK_PATH = 'blured_mask.png'

# Side length (pixels) of the square images sent to the model.
SIZE = 512

# Seconds to wait when downloading the generated image.
DOWNLOAD_TIMEOUT = 60

# Style keywords appended to every prompt.
STYLE_PROMPT = (
    'high contrast, film photography, film grain, single light, no dof, '
    'soft light, caustic, strange pattern, neo dada style, analog led strip '
    'lighting, 190mm lens, grainy picture'
)


def create_mask(input, strength=1.0):
    """Build a blurred inpainting mask from a product photo.

    Side effects: writes ``input.png`` (the upload), ``bg_removed.png`` (the
    background-removed photo resized to 512x512) and ``blured_mask.png`` (the
    mask) to the current working directory.

    Args:
        input: The uploaded product photo as a PIL image. (The name shadows
            the builtin but is kept for caller compatibility.)
        strength: How strongly the product region is protected, in [0, 1].
            At 1.0 (the default) the product region is 0 (black); lower
            values raise it towards gray, ``int(255 * (1 - strength))``.

    Returns:
        A 512x512 single-channel PIL image that is 255 (white) where the
        background was removed and dark over the product, Gaussian-blurred.

    Raises:
        ValueError: If ``strength`` is outside [0, 1].
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength!r}")

    input.save(INPUT_PATH)

    bg_removed = remove(input).resize((SIZE, SIZE))
    bg_removed.save(BG_REMOVED_PATH)

    # Classify pixels by transparency, not brightness: a grayscale == 0 test
    # would also flag pure-black product pixels as background.
    alpha = np.array(bg_removed.convert('RGBA').getchannel('A'))
    product_level = int(255 * (1 - strength))
    mask = np.where(alpha == 0, 255, product_level).astype(np.uint8)

    # Blur the mask edge, presumably so the repainted scene blends into the
    # product outline.
    blurred = GaussianBlur(kernel_size=11, sigma=20)(Image.fromarray(mask))
    blurred.save(MASK_PATH)
    return blurred


def generate_image(image, product_name, target_name):
    """Repaint the background of a product photo from text descriptions.

    Args:
        image: Uploaded product photo (PIL image), or ``None`` if the user
            clicked Generate without uploading.
        product_name: Free-text description of the product.
        target_name: Free-text description of where to place the product.

    Returns:
        The first image returned by the Replicate model, as a PIL image.

    Raises:
        gr.Error: If no image was uploaded (shown to the user in the UI).
        requests.HTTPError: If downloading the generated image fails.
    """
    if image is None:
        raise gr.Error("Please upload a photo of your product.")

    # Writes BG_REMOVED_PATH and MASK_PATH, which are uploaded below.
    create_mask(image)

    # Explicit separators so the user's text does not run into the fixed words.
    prompt = (
        f'a product photography photo of {product_name} on {target_name}, '
        f'{STYLE_PROMPT}'
    )

    model = replicate.models.get("cjwbw/stable-diffusion-v2-inpainting")
    version = model.versions.get(
        "f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec"
    )
    # Close both uploaded files as soon as the prediction call returns.
    with open(BG_REMOVED_PATH, "rb") as image_file, \
            open(MASK_PATH, "rb") as mask_file:
        output = version.predict(prompt=prompt, image=image_file, mask=mask_file)

    # output is presumably a list of image URLs; only the first is used.
    response = requests.get(output[0], timeout=DOWNLOAD_TIMEOUT)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload your product's photo", type='pil')
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(label="Where do you want to put your product?")
            image_button = gr.Button("Generate")
        with gr.Column():
            image_output = gr.Image()
    image_button.click(
        generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
        api_name='test',
    )

demo.launch()