|
import os
|
|
import torch
|
|
import hashlib
|
|
import argparse
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
from diffusers import AutoencoderKL
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
|
|
from vae import get_vae
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``res`` (int), ``fac``
        (float), ``ver`` ("v1" or "xl"), ``vae`` (str or None) and
        ``src`` (str).
    """
    cli = argparse.ArgumentParser(description="Preprocess images into latents")

    # (flags, keyword arguments) for every supported option, in display order.
    option_table = (
        (("-r", "--res"), dict(type=int, default=512, help="Source resolution")),
        (("-f", "--fac"), dict(type=float, default=1.5, help="Upscale factor")),
        (("-v", "--ver"), dict(choices=["v1", "xl"], default="v1", help="SD version")),
        (("--vae",), dict(help="Path to VAE (Optional)")),
        (("--src",), dict(default="raw", help="Source folder with images")),
    )
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)

    return cli.parse_args()
|
|
|
|
def encode(vae, img):
    """Encode an image into a latent sampled from the VAE posterior.

    Args:
        vae: Model whose ``encode(tensor)`` result exposes
            ``latent_dist.sample()``. It must already live on the CUDA
            device, because the input tensor is moved to ``"cuda"``.
        img: Image accepted by ``transforms.ToTensor`` (the callers in this
            file pass an RGB PIL image).

    Returns:
        The sampled latent as a CPU ``torch.Tensor`` detached from autograd.
        The input is given a leading batch dimension of 1 before encoding.
    """
    inp = transforms.ToTensor()(img).unsqueeze(0)
    inp = inp.to("cuda")
    # Pure inference: without no_grad() every forward pass would record an
    # autograd graph and hold on to GPU memory for no benefit.
    with torch.no_grad():
        # Shift the tensor range by x*2-1 (presumably [0, 1] -> [-1, 1],
        # the range the VAE expects - TODO confirm against the VAE in use).
        latent = vae.encode(inp * 2.0 - 1.0).latent_dist.sample()
    return latent.cpu().detach()
|
|
|
|
def scale(path, res):
    """Load an image, resize its shorter side to ``res`` and crop a square.

    The aspect ratio is preserved while resizing; the result is then cropped
    to the ``res`` x ``res`` region anchored at the top-left corner.

    Args:
        path: Path of the image file to load.
        res: Edge length in pixels of the returned square image.

    Returns:
        An RGB PIL image of size ``(res, res)``, or ``None`` when the
        shorter side of the source image is below 256 pixels.
    """
    # The context manager releases the file handle deterministically, even
    # for formats that keep the file open after decoding.
    with Image.open(path) as source:
        # Reject small images before paying for the RGB conversion.
        if min(source.height, source.width) < 256:
            return None
        img = source.convert('RGB')

    if img.width > img.height:
        target = (int(img.width / img.height * res), res)
    elif img.height > img.width:
        target = (res, int(img.height / img.width * res))
    else:
        target = (res, res)

    img = img.resize(target, Image.LANCZOS)
    return img.crop((0, 0, res, res))
|
|
|
|
def process_folder(vae, src_dir, ver, res):
    """Encode every image file in ``src_dir`` and cache the latents on disk.

    Each latent is saved as ``latents/{ver}_{res}px/<md5>.npy``, where
    ``<md5>`` is the MD5 digest of the source file's bytes. Files whose
    latent already exists are skipped, so the function can be re-run to
    resume an interrupted job.

    Args:
        vae: Model forwarded to ``encode``.
        src_dir: Folder containing the source images.
        ver: Version tag used in the output folder name.
        res: Target resolution forwarded to ``scale`` and used in the
            output folder name.
    """
    dst_dir = f"latents/{ver}_{res}px"
    # makedirs also creates the parent "latents" folder when it is missing.
    os.makedirs(dst_dir, exist_ok=True)

    for name in tqdm(os.listdir(src_dir)):
        src = os.path.join(src_dir, name)
        # Sub-directories cannot be hashed or opened as images.
        if not os.path.isfile(src):
            continue

        with open(src, 'rb') as handle:
            md5 = hashlib.md5(handle.read()).hexdigest()

        dst = os.path.join(dst_dir, f"{md5}.npy")
        if os.path.isfile(dst):
            continue

        img = scale(src, res)
        # scale() returns None for images below its minimum size.
        if img is None:
            continue

        latent = encode(vae, img)
        np.save(dst, latent)
|
|
|
|
def process_res(vae, src_dir, ver, res):
    """Encode the source folder and the optional test image at one resolution.

    Runs ``process_folder`` first. If ``test.png`` exists in the working
    directory and ``test_{ver}_{res}px.npy`` does not, the test image is
    scaled, encoded and saved under that name. The CUDA cache is released
    before returning in every case.

    Args:
        vae: Model forwarded to ``process_folder`` and ``encode``.
        src_dir: Folder containing the source images.
        ver: Version tag used in the output file names.
        res: Target resolution in pixels.
    """
    process_folder(vae, src_dir, ver, res)

    test_dst = f"test_{ver}_{res}px.npy"
    if os.path.isfile("test.png") and not os.path.isfile(test_dst):
        img = scale("test.png", res)
        # scale() returns None for images below its minimum size; encoding
        # None would crash, so the test latent is simply not written.
        if img is not None:
            latent = encode(vae, img)
            np.save(test_dst, latent)

    # Always hand cached GPU memory back before the next resolution runs.
    torch.cuda.empty_cache()
|
|
|
|
if __name__ == "__main__":
    # Root output folder shared by all per-resolution latent folders.
    os.makedirs("latents", exist_ok=True)

    args = parse_args()

    vae = get_vae(args.ver, args.vae)
    vae.to("cuda")

    # Encode at the source resolution first, then at the upscaled one.
    upscaled_res = int(args.res * args.fac)
    for target_res in (args.res, upscaled_res):
        process_res(vae, args.src, args.ver, target_res)
|
|
|