|
from PIL import Image |
|
import numpy as np |
|
from rembg import remove |
|
import cv2 |
|
import os |
|
from torchvision.transforms import GaussianBlur |
|
import gradio as gr |
|
import replicate |
|
import requests |
|
from io import BytesIO |
|
|
|
def create_mask(input, *, strength=1.0):
    """Build a blurred 512x512 mask marking the background of a product photo.

    The photo's background is removed with ``rembg.remove``; every pixel that
    ends up fully transparent is treated as background and set to 255 in the
    mask, while product pixels are set to ``int(255 * (1 - strength))``.
    The mask is then softened with a Gaussian blur (kernel 11, sigma 20).

    Side effects: writes ``input.png`` (the untouched upload),
    ``bg_removed.png`` (the 512x512 cutout) and ``blured_mask.png`` (the
    mask) into the current working directory. ``generate_image`` opens the
    latter two by name, so these file names must not change.
    NOTE(review): fixed file names mean concurrent requests overwrite each
    other's files.

    Args:
        input: PIL image of the product. (The name shadows the builtin but is
            kept so keyword callers keep working.)
        strength: Value in [0, 1]. 1.0 (the default, and the original
            behaviour) gives product pixels a mask value of 0; lower values
            raise the product pixels' mask value towards 255.

    Returns:
        The blurred mask as a single-channel PIL image.

    Raises:
        ValueError: If ``strength`` is outside [0, 1].
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be within [0, 1], got {strength!r}")

    size = (512, 512)
    input_path = 'input.png'
    bg_removed_path = 'bg_removed.png'
    mask_name = 'blured_mask.png'

    input.save(input_path)

    bg_removed = remove(input).resize(size)
    bg_removed.save(bg_removed_path)

    # Prefer the alpha channel to decide what is background: judging by
    # luminance == 0 alone would also classify pure-black product pixels as
    # background. Fall back to luminance if the cutout has no alpha band.
    if 'A' in bg_removed.getbands():
        coverage = np.array(bg_removed.getchannel('A'))
    else:
        coverage = np.array(bg_removed.convert('L'))
    background = coverage == 0

    # Background -> 255; product -> floor value derived from `strength`.
    product_value = int(255 * (1 - strength))
    mask_pixels = np.where(background, 255, product_value).astype(np.uint8)

    # Feather the hard edge between product and background.
    mask = GaussianBlur(11, 20)(Image.fromarray(mask_pixels))
    mask.save(mask_name)

    # Return the in-memory image instead of re-opening the file, which left a
    # lazily loaded file handle open.
    return mask
|
|
|
|
|
def generate_image(image, product_name, target_name):
    """Generate an advertising image of the product in a new setting.

    Calls ``create_mask`` to produce ``bg_removed.png`` and
    ``blured_mask.png`` in the working directory, sends both together with a
    text prompt to the ``cjwbw/stable-diffusion-v2-inpainting`` model on
    Replicate, then downloads the first returned image.

    Args:
        image: PIL image of the product, or ``None`` if nothing was uploaded.
        product_name: Text describing the product; inserted into the prompt.
        target_name: Text describing where the product should appear;
            inserted into the prompt.

    Returns:
        The generated picture as a PIL image.

    Raises:
        gr.Error: If no image was supplied or the model returned no output.
        requests.HTTPError: If downloading the generated image fails.
    """
    if image is None:
        raise gr.Error("Please upload a photo of your product first.")

    # Called for its side effect: it writes the two PNG files opened below.
    create_mask(image)

    prompt = (
        f"a photo of a {product_name} with {target_name} product photography"
    )

    model = replicate.models.get("cjwbw/stable-diffusion-v2-inpainting")
    version = model.versions.get(
        "f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec"
    )

    # Context managers make sure both uploads are closed even if the
    # prediction raises.
    with open("bg_removed.png", "rb") as image_file, \
            open("blured_mask.png", "rb") as mask_file:
        output = version.predict(
            prompt=prompt, image=image_file, mask=mask_file
        )

    if not output:
        raise gr.Error("The model did not return an image.")

    # A timeout prevents a stalled download from hanging the request forever;
    # raise_for_status avoids handing an HTTP error page to PIL.
    response = requests.get(output[0], timeout=60)
    response.raise_for_status()

    return Image.open(BytesIO(response.content))
|
|
|
# Gradio UI: the left column collects the photo and the two text prompts, the
# right column shows the generated picture.
# NOTE(review): the nesting below is reconstructed; the source had lost its
# indentation. Confirm the button belongs in the first column.
with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload your product's photo",
                type='pil',
            )
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(
                label="Where do you want to put your product?"
            )
            image_button = gr.Button("Generate")

        with gr.Column():
            image_output = gr.Image()

    # Wire the button to the generator; the endpoint is exposed as 'test'.
    image_button.click(
        generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
        api_name='test',
    )

demo.launch()