"""Gradio demo: cut a person out of a selfie with an image segmentation
model and paste them onto a Stable Diffusion generated background."""

import logging
import os
import time

# HF_HUB_ENABLE_HF_TRANSFER is read when huggingface_hub is first imported
# (diffusers and transformers pull it in), so set it before those imports.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from transformers import pipeline

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)

LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")

# SegFormer-B0 fine-tuned for ADE20K semantic segmentation; its label set
# includes the 'person' class used below.
seg_kwargs = {
    "task": "image-segmentation",
    "model": "nvidia/segformer-b0-finetuned-ade-512-512"
}

img_segmentation_model = pipeline(**seg_kwargs)

LOG.info("Loading diffusion model")

diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')


def image_preprocess(image: PIL.Image.Image):
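    """Convert the uploaded selfie to RGB and resize it for the diffusion model."""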
    LOG.info("Preprocessing image %s", image)
    start = time.time()

    image = image.convert("RGB")
    image = resize_image(image)

    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image


|
def resize_image(image: PIL.Image.Image):
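    """Scale the image so its longer side is 512 px, rounding both dimensions
    down to multiples of 8 as Stable Diffusion expects."""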
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    image = image.resize((width, height))
    return image


|
def extract_selfie_mask(threshold, image):
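    """Segment the image and turn the most confident 'person' segment into a
    soft, single-channel blending mask with values in [0, 1]."""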
    LOG.info("Extracting selfie mask")
    start = time.time()
    segments = img_segmentation_model(image)
    kept = None
    for s in segments:
        # Semantic segmentation pipelines report score=None; treat that as
        # full confidence so the segment is not discarded below.
        if s['score'] is None:
            s['score'] = 1
        if s['label'] == 'person' and s['score'] > 0.99:
            if kept is None or kept['score'] < s['score']:
                kept = s
    if kept is None:
        LOG.info("No person found in the photo, skipping")
        # All-zero single-channel mask: blendLinear keeps only the background.
        mask = np.zeros((image.size[1], image.size[0]), dtype='float32')
    else:
        # The pipeline returns an 8-bit PIL mask; normalize it to [0, 1] so
        # the threshold slider and the blending weights line up.
        mask = np.array(kept['mask'], dtype='float32') / 255

    # Binarize the mask, then dilate and blur it to feather the selfie's edges.
    cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
    cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
    cv2.blur(mask, (10, 10), dst=mask)

    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask


|
def generate_background(prompt, num_inference_steps, height, width):
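    """Generate a background from the prompt and return it as a BGR uint8
    array (all black if the safety checker flags the output)."""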
    LOG.info("Generating background")
    start = time.time()
    background = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width
    )
    nsfw = background.nsfw_content_detected[0]
    background = background.images[0]

    if nsfw:
        LOG.info("NSFW detected, skipping")
        # Fall back to a black canvas of the requested size.
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(background)

    # Flip RGB to BGR so the array matches OpenCV's channel order.
    background = background[:, :, ::-1].copy()

    elapsed = time.time() - start
    LOG.info("Background generated, %.2f seconds elapsed", elapsed)
    return background


|
def merge_selfie_and_background(selfie, background, mask):
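    """Alpha-blend the selfie over the generated background using the mask."""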
    LOG.info("Merging extracted selfie and generated background")
    selfie = np.array(selfie)

    # Flip the selfie to BGR to match the background, blend using the mask as
    # per-pixel weights, then convert back to RGB for PIL.
    selfie = selfie[:, :, ::-1].copy()
    cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
    selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
    selfie = PIL.Image.fromarray(selfie)
    return selfie


|
def demo(threshold, image, prompt, num_inference_steps):
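    """Gradio handler: preprocess the selfie, extract the person mask,
    generate a background from the prompt, and merge the two."""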
    LOG.info("Processing image")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        background = generate_background(prompt, num_inference_steps,
                                         image.size[1], image.size[0])
        output = merge_selfie_and_background(image, background, mask)
    except Exception:
        LOG.exception("Unexpected error occurred while processing the image")
        raise
    return output


iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05,
                  label="Selfie segmentation threshold", value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side",
                label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5,
                  label="Diffusion inference steps", value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])

iface.launch()