File size: 5,261 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be2809f
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be2809f
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import torch
from torch import nn
import logging

from pathlib import Path
import gradio as gr
import numpy as np
import cv2

import model_utils
from models.SSN import SSN

config_file = 'configs/SSN.yaml'
weight      = 'weights/SSN/0000001760.pt'
# Inference runs on CPU. Use torch.device('cuda:0') here to run on a GPU.
device      = torch.device('cpu')
model       = model_utils.load_model(config_file, weight, SSN, device)

# Default shadow compositing parameters; also the initial slider values in the UI.
DEFAULT_INTENSITY = 0.9
DEFAULT_GAMMA = 2.0

# NOTE(review): logging is not configured in this file, so with Python's default
# root level (WARNING) this INFO message is dropped unless configured elsewhere.
logging.info('Model loading succeed')

# Mutable UI state shared between the Gradio callbacks:
#   cur_rgba      -- last padded RGBA input (set by render_btn_fn)
#   cur_shadow    -- last shadow predicted by the model (set by render_btn_fn)
#   cur_intensity -- current intensity slider value
#   cur_gamma     -- current gamma slider value
cur_rgba = None
cur_shadow = None
cur_intensity = DEFAULT_INTENSITY
cur_gamma = DEFAULT_GAMMA

def resize(img, size):
    """Resize ``img`` so that its longer side equals ``size``, keeping the aspect ratio.

    :param img: image array of shape H x W or H x W x C
    :param size: target length, in pixels, of the longer side
    :returns: the resized image, with the same number of dimensions as ``img``
    """
    h, w = img.shape[:2]

    if h > w:
        newh = size
        neww = int(w / h * size)
    else:
        neww = size
        newh = int(h / w * size)

    # Very elongated inputs can round the short side down to 0 pixels, which
    # cv2.resize rejects; keep at least one pixel.
    neww = max(neww, 1)
    newh = max(newh, 1)

    # OpenCV's dsize argument is (width, height).
    resized_img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)

    # cv2.resize returns H x W for an H x W x 1 input; restore the channel axis
    # so the output rank matches the input rank.
    if len(img.shape) != len(resized_img.shape):
        resized_img = resized_img[..., None]

    return resized_img


def ibl_normalize(ibl, energy=30.0):
    """Rescale an IBL map so that its entries sum to ``energy``.

    :param ibl: light map array of any shape (presumably non-negative intensities)
    :param energy: target total energy after normalization
    :returns: ``ibl * energy / ibl.sum()``, or an all-zero float array with the
        same shape as ``ibl`` when its total energy is below 1e-3
    """
    total_energy = np.sum(ibl)
    # A (near-)empty light map would be divided by a tiny sum and blow up;
    # return an all-dark map of the same shape instead.
    if total_energy < 1e-3:
        return np.zeros(ibl.shape)

    return ibl * energy / total_energy


def padding_mask(rgba_input: np.ndarray, padding=50, canvas_size=256):
    """ Padding the mask input so that it fits the training dataset view range

    If the visible object (alpha > 0) already spans fewer than
    ``canvas_size - 2 * padding`` pixels both vertically and horizontally, the
    input is returned unchanged. A fully transparent input is also returned
    unchanged. Otherwise the whole input is shrunk so that its longer side
    equals ``canvas_size - 2 * padding``. It is then pasted at offset
    (``padding``, ``padding``) onto a ``canvas_size`` x ``canvas_size`` canvas
    with white RGB and zero alpha.

    Note: only the object's extent is checked, not its position in the image.

    :param rgba_input: H x W x 4 inputs, the first 3 channels are RGB, the last channel is the alpha
    :param padding: margin in pixels kept above and to the left of the pasted input
    :param canvas_size: side length of the square output canvas
    :returns: ``rgba_input`` itself, or a canvas_size x canvas_size x 4 padded RGBA array

    """
    padding_size = canvas_size - padding * 2

    alpha = rgba_input[:, :, -1]

    # Measure the extent of the foreground (non-transparent) pixels. The extent
    # of the background nearly always covers the whole image and says nothing
    # about how much room the object leaves.
    ys, xs = np.nonzero(alpha > 0)
    if ys.size == 0:
        # Nothing visible to fit into the view range.
        return rgba_input

    # if the object already leaves enough padding area
    if ys.max() - ys.min() < padding_size and xs.max() - xs.min() < padding_size:
        return rgba_input

    padding_output = np.zeros((canvas_size, canvas_size, 4))
    padding_output[..., :3] = 1.0

    padded_rgba = resize(rgba_input, padding_size)
    new_h, new_w = padded_rgba.shape[:2]

    padding_output[padding:padding + new_h, padding:padding + new_w, :] = padded_rgba

    return padding_output

def shadow_composite(rgba, shadow, intensity, gamma):
    """Alpha-blend the RGBA foreground over a shaded white background.

    The background shade is ``1 - shadow**gamma * intensity``. It is 1.0
    (white) where the shadow is 0 and darker where the shadow is strong.

    :param rgba: array whose channels 0-2 are RGB and channels 3+ are alpha
    :param shadow: H x W shadow map (an extra trailing channel axis is also accepted)
    :param intensity: multiplier applied to the gamma-corrected shadow
    :param gamma: exponent applied to the shadow before scaling
    :returns: tuple of (composited image, shade layer taken from channel 0)
    """
    if len(shadow.shape) == 2:
        # Give the shadow a channel axis so it broadcasts against RGB.
        shadow = shadow[..., np.newaxis]

    shade = 1.0 - shadow ** gamma * intensity

    alpha = rgba[..., 3:]
    color = rgba[..., :3]
    composite = (1.0 - alpha) * shade + color * alpha
    return composite, shade[..., 0]


def render_btn_fn(mask, ibl):
    """Render-button callback: predict a shadow for the mask and composite it.

    Stores the padded RGBA input in ``cur_rgba`` and the raw model shadow in
    ``cur_shadow``. The slider callbacks can then re-composite without
    re-running the model.

    :param mask: RGBA image from the "Mask" input (presumably uint8 in [0, 255],
        given the /255 scaling below)
    :param ibl: grayscale IBL sketch from the "IBL" sketchpad (presumably uint8
        in [0, 255])
    :returns: (composited image, shadow layer) for the two output images
    :raises gr.Error: if either input is missing
    """
    # Only these two globals are assigned here; the slider values are only read.
    global cur_rgba, cur_shadow

    print("Button clicked!")

    if mask is None or ibl is None:
        raise gr.Error("Please provide both a mask image and an IBL sketch.")

    mask = mask / 255.0
    ibl = ibl / 255.0

    # smoothing ibl
    ibl = cv2.GaussianBlur(ibl, (11, 11), 0)

    # padding mask so the object fits the view range used in training
    mask = padding_mask(mask)

    cur_rgba = np.copy(mask)

    print('mask shape: {}/{}/{}/{}, ibl shape: {}/{}/{}/{}'.format(mask.shape, mask.dtype, mask.min(), mask.max(),
                                                                   ibl.shape, ibl.dtype, ibl.min(), ibl.max()))

    # Only the alpha channel is passed to the model as its mask.
    alpha = mask[..., 3]

    # Downsample the IBL to 32 x 16 (width x height) and normalize its energy.
    ibl = ibl_normalize(cv2.resize(ibl, (32, 16)))

    x = {
        'mask': alpha,
        'ibl': ibl
    }
    shadow = model.inference(x)
    cur_shadow = np.copy(shadow)

    ret, shadow = shadow_composite(cur_rgba, shadow, cur_intensity, cur_gamma)
    return ret, shadow


def intensity_change(x):
    """Intensity-slider callback: store the new intensity and re-composite.

    :param x: new shadow intensity from the slider
    :returns: (composited image, shadow layer), or (None, None) if nothing has
        been rendered yet
    """
    global cur_intensity

    cur_intensity = x
    # Before the first render there is no input or shadow to composite; the
    # stored value is picked up by the next render.
    if cur_rgba is None or cur_shadow is None:
        return None, None
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow


def gamma_change(x):
    """Gamma-slider callback: store the new gamma and re-composite.

    :param x: new shadow gamma from the slider
    :returns: (composited image, shadow layer), or (None, None) if nothing has
        been rendered yet
    """
    global cur_gamma

    cur_gamma = x
    # Before the first render there is no input or shadow to composite; the
    # stored value is picked up by the next render.
    if cur_rgba is None or cur_shadow is None:
        return None, None
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow


# IBL sketch canvas size: 2:1 (width:height), the same aspect ratio as the
# 32 x 16 map that render_btn_fn resizes the sketch to.
ibl_h = 128
ibl_w = ibl_h * 2

with gr.Blocks() as demo:
    with gr.Row():
        mask_input = gr.Image(shape=(256, 256), image_mode="RGBA", label="Mask")
        ibl_input = gr.Sketchpad(shape=(ibl_w, ibl_h), image_mode="L", label="IBL", tool='sketch', invert_colors=True)
        output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="RGB", label="Output")
        shadow_output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="L", label="Shadow Layer")

    with gr.Row():
        intensity_slider = gr.Slider(0.0, 1.0, value=DEFAULT_INTENSITY, step=0.1, label="Intensity", info="Choose between 0.0 and 1.0")
        gamma_slider = gr.Slider(1.0, 4.0, value=DEFAULT_GAMMA, step=0.1, label="Gamma", info="Gamma correction for shadow")
        # A Button's visible text is its `value`; passing `label` leaves the
        # default caption in place.
        render_btn = gr.Button(value="Render")

    # Rendering runs the model; the sliders only re-composite the cached shadow.
    render_btn.click(render_btn_fn, inputs=[mask_input, ibl_input], outputs=[output, shadow_output])
    intensity_slider.release(intensity_change, inputs=[intensity_slider], outputs=[output, shadow_output])
    gamma_slider.release(gamma_change, inputs=[gamma_slider], outputs=[output, shadow_output])

    logging.info('Finished')


demo.launch()