Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import logging | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
import model_utils | |
from models.SSN import SSN | |
config_file = 'configs/SSN.yaml'
weight = 'weights/SSN/0000001760.pt'
# Inference runs on the CPU.
device = torch.device('cpu')
model = model_utils.load_model(config_file, weight, SSN, device)

# Default shadow-composite parameters; also used as the initial slider values.
DEFAULT_INTENSITY = 0.9
DEFAULT_GAMMA = 2.0
# NOTE(review): logging is not configured in this file, so this INFO message is
# dropped by the default WARNING-level root logger unless configured elsewhere.
logging.info('Model loading succeed')

# Mutable UI state shared by the Gradio callbacks: the last padded RGBA input,
# the last raw shadow returned by the model, and the current slider values.
cur_rgba = None
cur_shadow = None
cur_intensity = DEFAULT_INTENSITY
cur_gamma = DEFAULT_GAMMA
def resize(img, size):
    """Resize ``img`` so that its longer side equals ``size``, keeping aspect ratio.

    :param img: H x W or H x W x C image array accepted by ``cv2.resize``.
    :param size: target length, in pixels, of the longer side.
    :returns: the resized image, with the same number of dimensions as ``img``.
    """
    h, w = img.shape[:2]
    if h > w:
        newh = size
        # max(1, ...) keeps very thin images from collapsing to a 0-pixel side,
        # which cv2.resize rejects.
        neww = max(1, int(w / h * size))
    else:
        neww = size
        newh = max(1, int(h / w * size))
    resized_img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
    # cv2.resize drops a trailing channel axis of length 1 (H x W x 1 -> H x W);
    # restore it so the caller gets back an array of the same rank.
    if resized_img.ndim != img.ndim:
        resized_img = resized_img[..., None]
    return resized_img
def ibl_normalize(ibl, energy=30.0):
    """Rescale a light (IBL) map so that its values sum to ``energy``.

    :param ibl: array of light intensities, any shape (presumably non-negative).
    :param energy: total energy the normalized map should carry.
    :returns: ``ibl * energy / sum(ibl)``, or an all-zero float array of the
        same shape when the total energy is below 1e-3.
    """
    total_energy = np.sum(ibl)
    if total_energy < 1e-3:
        # Nearly empty map: return zeros rather than dividing by ~0.
        # zeros_like keeps the input's shape for any rank.
        return np.zeros_like(ibl, dtype=float)
    return ibl * energy / total_energy
def padding_mask(rgba_input: np.ndarray, canvas_size: int = 256, padding: int = 40) -> np.ndarray:
    """Pad the mask input so that it fits the training dataset view range.

    The object (pixels with alpha > 0) must fit inside a square of side
    ``canvas_size - 2 * padding`` on a ``canvas_size`` x ``canvas_size`` canvas.
    If the input is already canvas-sized and its object fits, it is returned
    unchanged. Otherwise the whole input is resized so that its longer side
    equals the inner size, then pasted in the centre of a white, fully
    transparent canvas.

    :param rgba_input: H x W x 4 input; the first 3 channels are RGB, the last
        channel is alpha (the caller in this file passes values in [0, 1]).
    :param canvas_size: side length of the square output canvas.
    :param padding: margin kept free on each side of the canvas.
    :returns: the input itself, or a canvas_size x canvas_size x 4 float array
        clipped to [0, 1].
    """
    inner_size = canvas_size - padding * 2
    h, w = rgba_input.shape[:2]

    # Bounding box of the object, i.e. the non-transparent pixels (alpha > 0),
    # not the background. A fully transparent input has no bounding box and
    # simply skips the early return instead of crashing on min() of an empty array.
    rows, cols = np.nonzero(rgba_input[..., -1] > 0)
    if rows.size > 0 and (h, w) == (canvas_size, canvas_size):
        fits_h = rows.max() - rows.min() < inner_size
        fits_w = cols.max() - cols.min() < inner_size
        if fits_h and fits_w:
            # Already a full-size canvas with enough margin around the object.
            return rgba_input

    # White RGB, fully transparent (alpha = 0) canvas.
    padding_output = np.zeros((canvas_size, canvas_size, 4))
    padding_output[..., :3] = 1.0

    padded_rgba = resize(rgba_input, inner_size)
    new_h, new_w = padded_rgba.shape[:2]
    top = (canvas_size - new_h) // 2
    left = (canvas_size - new_w) // 2
    padding_output[top:top + new_h, left:left + new_w, :] = padded_rgba
    return np.clip(padding_output, 0.0, 1.0)
def shadow_composite(rgba, shadow, intensity, gamma):
    """Blend a shadow map underneath an RGBA foreground.

    The shadow is turned into a shading layer ``1 - intensity * shadow**gamma``
    (1.0 means unshadowed). The foreground colour is shown where alpha is high,
    and the shading layer where alpha lets the background through.

    :param rgba: image whose channels 0-2 are colour and channel 3 is alpha
        (assumed H x W x 4).
    :param shadow: H x W or H x W x 1 shadow map.
    :param intensity: scale applied to the gamma-corrected shadow.
    :param gamma: exponent applied to the shadow values.
    :returns: ``(composite, shading)``, where ``composite`` has the colour
        channels and ``shading`` is the 2-D shading layer.
    """
    if len(shadow.shape) == 2:
        # Add a channel axis so the map broadcasts against the colour channels.
        shadow = shadow[..., None]
    color, alpha = rgba[..., :3], rgba[..., 3:]
    shading = 1.0 - intensity * shadow ** gamma
    composite = color * alpha + (1.0 - alpha) * shading
    return composite, shading[..., 0]
def render_btn_fn(mask, ibl):
    """Render a shadow for the uploaded RGBA image under the drawn light map.

    Updates the module-level ``cur_rgba`` (padded input) and ``cur_shadow``
    (raw model output) so the slider callbacks can re-composite without
    re-running the model.

    :param mask: image from the "RGBA" input, presumably uint8 in [0, 255]
        (it is divided by 255 here).
    :param ibl: single-channel sketch from the "IBL" sketchpad, presumably
        uint8 in [0, 255].
    :returns: ``(composite, shadow_layer)`` as produced by ``shadow_composite``.
    :raises gr.Error: if either input is missing, so the UI shows a message
        instead of a TypeError.
    """
    global cur_rgba, cur_shadow, cur_gamma, cur_intensity
    print("Button clicked!")
    if mask is None or ibl is None:
        raise gr.Error("Please upload an RGBA image and draw an IBL before rendering.")

    mask = np.clip(mask / 255.0, 0.0, 1.0)
    ibl = ibl / 255.0
    # Smooth the hand-drawn light map before it is downsampled.
    ibl = cv2.GaussianBlur(ibl, (11, 11), 0)
    # Fit the object into the view range used for training.
    mask = padding_mask(mask)
    cur_rgba = np.copy(mask)
    print('mask shape: {}/{}/{}/{}, ibl shape: {}/{}/{}/{}'.format(mask.shape, mask.dtype, mask.min(), mask.max(),
                                                                    ibl.shape, ibl.dtype, ibl.min(), ibl.max()))

    alpha = mask[..., 3]
    # cv2 dsize is (width, height): downsample to 32 wide x 16 high, then
    # normalize the total light energy.
    ibl = ibl_normalize(cv2.resize(ibl, (32, 16)))
    x = {
        'mask': alpha,
        'ibl': ibl
    }
    shadow = model.inference(x)
    cur_shadow = np.copy(shadow)
    ret, shadow = shadow_composite(cur_rgba, shadow, cur_intensity, cur_gamma)
    return ret, shadow
def intensity_change(x):
    """Re-composite the last rendered shadow using a new intensity.

    :param x: new value from the "Intensity" slider; stored in ``cur_intensity``.
    :returns: ``(composite, shadow_layer)``, or ``(None, None)`` when nothing
        has been rendered yet.
    """
    global cur_intensity
    cur_intensity = x
    # Before the first render there is no image or shadow to re-composite.
    # The slider value is still remembered for the next render.
    if cur_rgba is None or cur_shadow is None:
        return None, None
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow
def gamma_change(x):
    """Re-composite the last rendered shadow using a new gamma.

    :param x: new value from the "Gamma" slider; stored in ``cur_gamma``.
    :returns: ``(composite, shadow_layer)``, or ``(None, None)`` when nothing
        has been rendered yet.
    """
    global cur_gamma
    cur_gamma = x
    # Before the first render there is no image or shadow to re-composite.
    # The slider value is still remembered for the next render.
    if cur_rgba is None or cur_shadow is None:
        return None, None
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow
def update_input(mask):
    """Pass the selected example through unchanged; it is used as the
    ``gr.Examples`` callback whose output is the RGBA input component."""
    selected = mask
    return selected
# IBL sketchpad canvas size, with a 2:1 width:height aspect ratio.
ibl_h = 128
ibl_w = ibl_h * 2

with gr.Blocks() as demo:
    with gr.Row():
        mask_input = gr.Image(shape=None, width=256, height=256, image_mode="RGBA", label="RGBA")
        ibl_input = gr.Sketchpad(shape=(ibl_w, ibl_h), image_mode="L", label="IBL", tool='sketch', invert_colors=True)
        output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="RGB", label="Output")
        shadow_output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="L", label="Shadow Layer")
    with gr.Row():
        intensity_slider = gr.Slider(0.0, 1.0, value=DEFAULT_INTENSITY, step=0.1, label="Intensity", info="Choose between 0.0 and 1.0")
        gamma_slider = gr.Slider(1.0, 4.0, value=DEFAULT_GAMMA, step=0.1, label="Gamma", info="Gamma correction for shadow")
    # A button's caption comes from its first argument (``value``); passing
    # ``label=`` does not set it, so the button would show the default text.
    render_btn = gr.Button("Render")
    with gr.Row():
        gr.Examples(
            examples=[['imgs/woman.png'], ['imgs/man.png'], ['imgs/plant1.png'], ['imgs/human2.png'], ['imgs/cloud.png']],
            fn=update_input,
            inputs=[mask_input],
            outputs=mask_input
        )

    # Render runs the model; the sliders only re-composite the cached shadow.
    render_btn.click(render_btn_fn, inputs=[mask_input, ibl_input], outputs=[output, shadow_output])
    intensity_slider.release(intensity_change, inputs=[intensity_slider], outputs=[output, shadow_output])
    gamma_slider.release(gamma_change, inputs=[gamma_slider], outputs=[output, shadow_output])

logging.info('Finished')
demo.launch()