Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from PIL import Image | |
from diffusers.models import AutoencoderKL | |
def main(args): | |
# Setup PyTorch: | |
torch.manual_seed(args.seed) | |
torch.set_grad_enabled(False) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# create and load model | |
vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device) | |
# load image | |
img_path = args.image_path | |
out_path = args.image_path.replace('.jpg', '_vae.jpg').replace('.jpeg', '_vae.jpeg').replace('.png', '_vae.png') | |
input_size = args.image_size | |
img = Image.open(img_path).convert("RGB") | |
# preprocess | |
size_org = img.size | |
img = img.resize((input_size, input_size)) | |
img = np.array(img) / 255. | |
x = 2.0 * img - 1.0 # x value is between [-1, 1] | |
x = torch.tensor(x) | |
x = x.unsqueeze(dim=0) | |
x = torch.einsum('nhwc->nchw', x) | |
x_input = x.float().to("cuda") | |
# inference | |
with torch.no_grad(): | |
# Map input images to latent space + normalize latents: | |
latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215) | |
# reconstruct: | |
output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] | |
# postprocess | |
output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] | |
sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() | |
# save | |
Image.fromarray(sample).save(out_path) | |
print("Reconstructed image is saved to {}".format(out_path)) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--image-path", type=str, default="assets/example.jpg") | |
parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse") | |
parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) | |
parser.add_argument("--seed", type=int, default=0) | |
args = parser.parse_args() | |
main(args) |