# Spaces:
# Runtime error
# Runtime error
# %% | |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
from networks.mat import Generator | |
import gradio as gr | |
import gradio.components as gc | |
import base64 | |
import glob | |
import os | |
import random | |
import re | |
from http import HTTPStatus | |
from io import BytesIO | |
from typing import Dict, List, NamedTuple, Optional, Tuple | |
import click | |
import cv2 | |
import numpy as np | |
import PIL.Image | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image, ImageDraw, ImageOps | |
from pydantic import BaseModel | |
import dnnlib | |
import legacy | |
# Module-level placeholder: `pyspng` is bound to None here; nothing in the
# visible part of this file reads it.
pyspng = None
def num_range(s: str) -> List[int]:
    """Parse a numeric selection string into a list of ints.

    Two forms are accepted:

    * a closed range ``'a-c'``, expanded to ``[a, a+1, ..., c]``
      (an inverted range such as ``'5-3'`` yields an empty list);
    * a comma separated list ``'a,b,c'``.

    Surrounding whitespace is ignored for both forms. The list form always
    tolerated it (``int(' 3 ')`` is valid); stripping first makes the range
    form behave the same way instead of failing on ``' 1-3 '``.

    Raises:
        ValueError: if an item is not a valid integer literal.
    """
    text = s.strip()
    bounds = re.fullmatch(r'(\d+)-(\d+)', text)
    if bounds:
        first, last = (int(group) for group in bounds.groups())
        return list(range(first, last + 1))
    return [int(item) for item in text.split(',')]
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters and buffers from ``src_module`` into ``dst_module``.

    Tensors are matched by the names reported by ``named_params_and_buffers``.
    Each matching destination tensor is overwritten in place with a detached
    copy of the source value, and its ``requires_grad`` flag is re-applied
    afterwards.

    Args:
        src_module: ``torch.nn.Module`` providing the values.
        dst_module: ``torch.nn.Module`` whose tensors are overwritten.
        require_all: when true, every tensor of ``dst_module`` must have a
            same-named counterpart in ``src_module``; a missing one trips an
            assertion. When false, unmatched destination tensors are skipped.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    source_by_name = dict(named_params_and_buffers(src_module))
    for key, target in named_params_and_buffers(dst_module):
        found = key in source_by_name
        assert found or (not require_all)
        if not found:
            continue
        # copy_ works in place and returns the destination tensor itself.
        copied = target.copy_(source_by_name[key].detach())
        copied.requires_grad_(target.requires_grad)
def params_and_buffers(module):
    """Return all parameters of ``module`` followed by all of its buffers.

    The result is a single flat list; ``module`` must be a ``torch.nn.Module``.
    """
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters, then all buffers.

    The result is a single flat list; ``module`` must be a ``torch.nn.Module``.
    """
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
class Inpainter:
    """Holds one pretrained generator and runs image + mask pairs through it.

    The weights are read from a network pickle (its ``G_ema`` entry) and copied
    into a freshly constructed ``Generator`` that is kept in eval mode with
    gradients disabled.
    """

    def __init__(self,
                 network_pkl,
                 resolution=512,
                 truncation_psi=1,
                 noise_mode='const',
                 sdevice='cpu'
                 ):
        """Load the network pickle and build the generator.

        Args:
            network_pkl: path or URL handed to ``dnnlib.util.open_url``.
            resolution: stored on the instance; the generator itself is built
                at ``min(resolution, 512)``.
            truncation_psi: forwarded to the generator on every call.
            noise_mode: forwarded to the generator (see ``generate_images2``
                for the one case where it is overridden).
            sdevice: device string passed to ``torch.device``.
        """
        self.resolution = resolution
        self.truncation_psi = truncation_psi
        self.noise_mode = noise_mode
        print(f'Loading networks from: {network_pkl}')
        self.device = torch.device(sdevice)

        with dnnlib.util.open_url(network_pkl) as f:
            saved = legacy.load_network_pkl(f)['G_ema']
            saved = saved.to(self.device)
            G_saved = saved.eval().requires_grad_(False)  # type: ignore

        # The generator is never built larger than 512.
        net_res = min(resolution, 512)
        generator = Generator(
            z_dim=512,
            c_dim=0,
            w_dim=512,
            img_resolution=net_res,
            img_channels=3
        )
        self.G = generator.to(self.device).eval().requires_grad_(False)
        copy_params_and_buffers(G_saved, self.G, require_all=True)

    @staticmethod
    def _seed_all(seed):
        """Seed Python's, NumPy's and torch's (CPU and CUDA) random generators."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @staticmethod
    def _as_chw_rgb(image):
        """Convert an image to a channel-first array with at most 3 channels.

        A 2-D (single channel) input is expanded to three identical channels;
        any channel beyond the third is dropped.
        """
        pixels = np.array(image)
        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]  # HW => HWC
            pixels = np.repeat(pixels, 3, axis=2)
        pixels = pixels.transpose(2, 0, 1)  # HWC => CHW
        return pixels[:3]

    def generate_images2(
        self,
        dpath: List[PIL.Image.Image],
        mpath: List[Optional[PIL.Image.Image]],
        seed: int = 42,
    ):
        """
        Generate images using pretrained network pickle.

        Args:
            dpath: input images, one per item.
            mpath: mask images, paired with ``dpath`` by position (``zip``
                stops at the shorter list). Mask values are scaled from
                0..255 to 0..1 before being handed to the generator.
                NOTE(review): the annotation allows ``None`` entries but
                nothing here handles them - confirm with callers.
            seed: when not ``None``, all RNGs are seeded once up front; when
                ``None``, they are re-seeded with the item index before each
                item.

        Returns:
            A list of RGB ``PIL.Image`` objects, one per processed pair.
        """
        psi = self.truncation_psi
        # Any resolution other than 512 forces the 'random' noise mode.
        noise = self.noise_mode if self.resolution == 512 else 'random'

        reseed_each_item = seed is None
        if not reseed_each_item:
            self._seed_all(seed)

        # no Labels.
        label = torch.zeros([1, self.G.c_dim], device=self.device)

        results = []
        with torch.no_grad():
            for index, (picture, mask_image) in enumerate(zip(dpath, mpath)):
                if reseed_each_item:
                    self._seed_all(index)

                # Pixel values 0..255 are mapped to -1..1, with a batch axis added.
                pixels = torch.from_numpy(self._as_chw_rgb(picture))
                image = (pixels.float().to(self.device) / 127.5 - 1).unsqueeze(0)

                # assumes the mask converts to a 2-D array, giving (1, 1, H, W)
                # after the two unsqueezes - TODO confirm
                mask_values = np.array(mask_image).astype(np.float32) / 255.0
                mask = torch.from_numpy(mask_values).float().to(self.device)
                mask = mask.unsqueeze(0).unsqueeze(0)

                latent = np.random.randn(1, self.G.z_dim)
                z = torch.from_numpy(latent).to(self.device)

                output = self.G(image, mask, z, label,
                                truncation_psi=psi, noise_mode=noise)

                # Back to channel-last uint8 in 0..255.
                output = output.permute(0, 2, 3, 1) * 127.5 + 127.5
                output = output.round().clamp(0, 255).to(torch.uint8)
                frame = output[0].cpu().numpy()
                results.append(PIL.Image.fromarray(frame, 'RGB'))
        return results
# if __name__ == "__main__": | |
# generate_images() # pylint: disable=no-value-for-parameter | |
# ---------------------------------------------------------------------------- | |
def mask_to_alpha(img, mask):
    """Return a copy of ``img`` whose alpha channel is set from ``mask``.

    The input image is left untouched; ``putalpha`` is applied to the copy.
    """
    with_alpha = img.copy()
    with_alpha.putalpha(mask)
    return with_alpha
def blend(src, target, mask):
    """Mix ``src`` and ``target`` per pixel, weighted by ``mask``.

    ``mask`` gets a trailing axis so it broadcasts over the channel axis.
    A weight of 0 keeps ``src``, a weight of 1 takes ``target``, and values in
    between mix the two linearly. The result is cast to ``uint8`` (fractional
    parts are truncated) and returned as a PIL image.
    """
    weight = np.expand_dims(mask, axis=-1)
    src_part = (1-weight) * src
    target_part = weight * target
    mixed = src_part + target_part
    return Image.fromarray(mixed.astype(np.uint8))
def pad(img, size=(128, 128), tosize=(512, 512), border=1):
    """Shrink ``img``, trim its border and centre it on a transparent canvas.

    Args:
        img: source PIL image.
        size: target ``(width, height)`` for the shrunken image, or a float
            scale factor applied to ``img.size``.
        tosize: ``(width, height)`` of the returned canvas and mask.
        border: number of pixels cropped from every edge after resizing.

    Returns:
        ``(canvas, mask)``: an RGBA canvas with the trimmed image pasted at its
        centre, and an ``'L'`` mask that is 255 over the pasted area and 0
        elsewhere. If the trimmed image has an alpha band, that band replaces
        the 255 fill inside the pasted area.
    """
    if isinstance(size, float):
        scale = size
        size = (int(img.size[0] * scale), int(img.size[1] * scale))

    w, h = tosize
    canvas = Image.new('RGBA', (w, h))

    # Resize first, then remove the border.
    shrunk = ImageOps.crop(
        img.resize(size, resample=Image.Resampling.NEAREST),
        border=border,
    )

    inner_w, inner_h = size
    inner_w -= border*2
    inner_h -= border*2
    top_left = ((w-inner_w)//2, (h-inner_h)//2)
    canvas.paste(shrunk, top_left)

    mask = Image.new('L', (w, h))
    mask.paste(Image.new('L', (inner_w, inner_h), 255), top_left)
    if 'A' in shrunk.getbands():
        mask.paste(shrunk.getchannel('A'), top_left)
    return canvas, mask
def b64_to_img(b64):
    """Decode a base64 payload and open the resulting bytes as a PIL image."""
    raw_bytes = base64.b64decode(b64)
    stream = BytesIO(raw_bytes)
    return Image.open(stream)
def img_to_b64(img):
    """Serialise ``img`` as PNG and return the bytes base64-encoded as text."""
    with BytesIO() as buffer:
        img.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
class Predictor:
    """Keeps one ``Inpainter`` per selectable model and runs outpainting.

    All models are constructed eagerly in ``__init__`` and stored in
    ``self.models`` under their display name.
    """

    # Display name -> network pickle path, in load order.
    _MODEL_PICKLES = {
        "places2":
            'models/Places_512_FullData.pkl',
        "places2+laion300k":
            'models/Places_512_FullData+LAION300k.pkl',
        "places2+laion300k+laion300k(opmasked)":
            'models/Places_512_FullData+LAION300k+OPM300k.pkl',
        "places2+laion300k+laion1200k(opmasked)":
            'models/Places_512_FullData+LAION300k+OPM1200k.pkl',
    }

    def __init__(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.models = {
            name: Inpainter(
                network_pkl=pickle_path,
                resolution=512,
                truncation_psi=1.,
                noise_mode='const',
            )
            for name, pickle_path in self._MODEL_PICKLES.items()
        }

    # The arguments and types the model takes as input
    def predict(
        self,
        img: Image.Image,
        tosize=(512, 512),
        border=5,
        seed=42,
        size=0.5,
        model='places2',
    ) -> Tuple[Image.Image, Image.Image, Image.Image]:
        """Run a single prediction on the model.

        The input is shrunk and centred on a ``tosize`` canvas by ``pad``, the
        canvas and its mask are resized to 512x512 and handed to the selected
        inpainter, and the generated image is resized back to ``tosize``.
        Finally the padded original is blended back in wherever the mask is
        set, to remove inpainting/scaling artifacts in the known region.

        Args:
            img: source image.
            tosize: ``(width, height)`` of the outputs.
            border: pixels cropped from each edge of the shrunken input.
            seed: RNG seed forwarded to the inpainter.
            size: size or float scale factor forwarded to ``pad``.
            model: key into ``self.models``.

        Returns:
            A 3-tuple ``(inpainted, inpainted_with_alpha, inverted_mask)``:
            the blended result, the same image with the pad mask installed as
            its alpha channel, and the inverted pad mask.

        Raises:
            KeyError: if ``model`` is not a known model name.
        """
        padded, keep_mask = pad(
            img,
            size=size,  # (328, 328),
            tosize=tosize,
            border=border
        )
        generated = self.models[model].generate_images2(
            dpath=[padded.resize((512, 512), resample=Image.Resampling.NEAREST)],
            mpath=[keep_mask.resize((512, 512), resample=Image.Resampling.NEAREST)],
            seed=seed,
        )
        raw = generated[0].convert('RGBA').resize(
            tosize, resample=Image.Resampling.NEAREST)
        # paste original image to remove inpainting/scaling artifacts:
        # blend weight is 0 (keep the padded original) where the mask is 255.
        inpainted = blend(
            padded,
            raw,
            1-(np.array(keep_mask) / 255)
        )
        with_alpha = mask_to_alpha(inpainted, keep_mask)
        return inpainted, with_alpha, ImageOps.invert(keep_mask)
# Single module-level Predictor used by the UI callback; constructing it
# builds every Inpainter the Predictor defines, at import time.
predictor = Predictor()
# %% | |
def _outpaint(img, tosize, border, seed, size, model):
    """Click handler: forward the UI values to the shared ``predictor``.

    ``tosize`` is a single number used for both output dimensions, and
    ``size`` is coerced to ``float`` before being passed on. The return value
    is whatever ``predictor.predict`` returns.
    """
    return predictor.predict(
        img,
        tosize=(tosize, tosize),
        border=border,
        seed=seed,
        size=float(size),
        model=model,
    )
# %% | |
# Gradio UI: a header with example GIFs, a Run button, an input column
# (image + sliders + model choice) and an output column (three images).
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the layout against the running app.
with gr.Blocks() as demo:
    maturl = 'https://github.com/fenglinglwb/MAT'
    gr.Markdown('''
# MAT Primer for Stable Diffusion
## based on MAT: Mask-Aware Transformer for Large Hole Image Inpainting
### create a primer for use in stable diffusion outpainting
''')
    gr.HTML(f'''<a href="{maturl}">{maturl}</a>''')
    with gr.Box():
        with gr.Row():
            gr.Markdown("""example with strength 0.5""")
        with gr.Row():
            gr.HTML("<img src='file/op.gif'> ")
            gr.HTML("<img src='file/process.gif'>")
    btn = gr.Button("Run", variant="primary")
    with gr.Row():
        with gr.Column():
            searchimage = gc.Image(
                shape=(224, 224), label="image", type='pil', image_mode='RGBA')
            to_size = gc.Slider(1, 1920, 512, step=1, label='output size')
            # Minimum is 0 so that the default value (0, i.e. no crop) lies
            # inside the slider's range; it used to be 1 with a default of 0.
            border = gc.Slider(
                0, 50, 0, step=1, label='border to crop from the image before outpainting')
            seed = gc.Slider(1, 65536, 10, step=1, label='seed')
            size = gc.Slider(0, 1, .5, step=0.01,
                             label='scale of the image before outpainting')
            model = gc.Dropdown(
                choices=['places2',
                         'places2+laion300k',
                         'places2+laion300k+laion300k(opmasked)',
                         'places2+laion300k+laion1200k(opmasked)'
                         ],
                value='places2+laion300k+laion1200k(opmasked)',
                label='model',
            )
        with gr.Column():
            outwithoutalpha = gc.Image(
                label="primed image without alpha channel", type='pil', image_mode='RGBA')
            mask = gc.Image(label="outpainting mask", type='pil')
            out = gc.Image(label="primed image with alpha channel",
                           type='pil', image_mode='RGBA')
    # Output order matches the handler's return order:
    # (image without alpha, image with alpha, mask).
    btn.click(
        fn=_outpaint,
        inputs=[searchimage, to_size, border, seed, size, model],
        outputs=[outwithoutalpha, out, mask])
# %% launch
# Start the Gradio app. This runs whenever the module is executed or imported
# (there is no `__main__` guard).
demo.launch()