dikdimon's picture
Upload extensions using SD-Hub extension
c336648 verified
raw
history blame
1.27 kB
import torch
from diffusers import AutoencoderKL
def get_vae(version, file_path=None, fp16=False):
    """Load an ``AutoencoderKL`` VAE from a single file or a default HF repo.

    Args:
        version: Model family; one of ``"v1"``, ``"v2"`` or ``"xl"``.
        file_path: Optional path to a single-file VAE checkpoint. When truthy,
            the VAE is loaded with ``AutoencoderKL.from_single_file`` and no
            Hugging Face repo is contacted.
        fp16: Request float16 weights. Only honoured for Hugging Face repo
            loads; single-file loads are called without ``torch_dtype``.
            For ``"xl"`` an fp16 request switches to the
            ``madebyollin/sdxl-vae-fp16-fix`` repo instead of the base repo.

    Returns:
        The loaded ``AutoencoderKL`` instance.

    Raises:
        ValueError: If ``version`` is not one of the supported values.
    """
    # Per-family ``image_size`` handed to from_single_file; the keys also
    # define the set of accepted ``version`` values.
    single_file_image_sizes = {"v1": 512, "v2": 768, "xl": 1024}
    if version not in single_file_image_sizes:
        # Raise instead of input()/exit(): a library helper must not block on
        # stdin or terminate the interpreter (fatal inside a web UI process).
        raise ValueError(
            f"Invalid VAE version {version!r}; expected one of "
            f"{sorted(single_file_image_sizes)}"
        )

    if file_path:
        return AutoencoderKL.from_single_file(
            file_path,
            image_size=single_file_image_sizes[version],
        )

    dtype = torch.float16 if fp16 else torch.float32

    if version == "v1":
        # The original runwayml/stable-diffusion-v1-5 repo was taken down;
        # this is the maintained copy of the same weights.
        return AutoencoderKL.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=dtype,
        )

    if version == "v2":
        return AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            subfolder="vae",
            torch_dtype=dtype,
        )

    # version == "xl"
    if fp16:
        # Dedicated fp16 repo rather than casting the base SDXL VAE;
        # presumably the base VAE is not fp16-safe (hence "fp16-fix").
        return AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
    return AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
    )