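"""Lightweight wrappers around the VAEs used by SD v1.x, SDXL, SD3, and Stable
Cascade (stages C and A), plus an identity fallback, behind a common
encode/decode interface."""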
import torch
from diffusers import AutoencoderKL
# Default precision and device used by all VAE wrappers below.
DTYPE = torch.float16
DEVICE = "cuda:0"
class SDv1_VAE:
    # Standard SD v1.x KL VAE: 4-channel latents at 1/8 of the image resolution.
    scale = 1/8
    channels = 4
def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
self.device = device
self.dtype = dtype
        # ft-MSE fine-tuned VAE weights for SD v1.x.
        self.model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse"
        )
self.model.eval().to(self.dtype).to(self.device)
        if dec_only:
            # Drop the encoder to save memory when only decoding is needed.
            del self.model.encoder
    def encode(self, image):
        image = image.to(self.dtype).to(self.device)
        # Inputs are assumed to be in [0, 1]; the KL VAE expects [-1, 1].
        image = (image * 2.0) - 1.0
        with torch.no_grad():
            latent = self.model.encode(image).latent_dist.sample()
        return latent.to(image.dtype).to(image.device)
    def decode(self, latent, grad=False):
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            # Keep the autograd graph so gradients can flow through the decoder.
            out = self.model.decode(latent).sample
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        # Map the decoder output from [-1, 1] back to [0, 1].
        out = torch.clamp(out, min=-1.0, max=1.0)
        out = (out + 1.0) / 2.0
        return out.to(latent.dtype).to(latent.device)
class SDXL_VAE(SDv1_VAE):
scale = 1/8
channels = 4
def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
self.device = device
self.dtype = dtype
        # fp16-safe SDXL VAE, avoids NaN outputs when run in half precision.
        self.model = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix"
        )
self.model.eval().to(self.dtype).to(self.device)
if dec_only:
del self.model.encoder
class SDv3_VAE(SDv1_VAE):
    # SD3 uses a 16-channel latent space, still at 1/8 of the image resolution.
    scale = 1/8
    channels = 16
def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
self.device = device
self.dtype = dtype
self.model = AutoencoderKL.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="vae"
)
self.model.eval().to(self.dtype).to(self.device)
if dec_only:
del self.model.encoder
class CascadeC_VAE(SDv1_VAE):
    # Stable Cascade stage C: EfficientNet-encoded latents at 1/32 resolution, 16 channels.
    scale = 1/32
    channels = 16
    def __init__(self, device=DEVICE, dtype=DTYPE, **kwargs):
        self.device = device
        self.dtype = dtype
        # For now this just piggybacks off of kohya-ss/sd-scripts for the encoder implementation.
        from library import stable_cascade as sc
        from safetensors.torch import load_file
        from huggingface_hub import hf_hub_download
        self.model = sc.EfficientNetEncoder()
        # Encoder-only weights from the official Stable Cascade repository.
        self.model.load_state_dict(load_file(
            str(hf_hub_download(
                repo_id="stabilityai/stable-cascade",
                filename="effnet_encoder.safetensors",
            ))
        ))
        self.model.eval().to(self.dtype).to(self.device)
class CascadeA_VAE():
    # Stable Cascade stage A: Paella VQ model, 4-channel latents at 1/4 resolution.
    scale = 1/4
    channels = 4
def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
self.device = device
self.dtype = dtype
        # This import path may change in future diffusers releases.
        from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
        self.model = PaellaVQModel.from_pretrained(
            "stabilityai/stable-cascade",
            subfolder="vqgan"
        )
self.model.eval().to(self.dtype).to(self.device)
if dec_only:
del self.model.encoder
    def encode(self, image):
        image = image.to(self.dtype).to(self.device)
        # Unlike the KL VAEs above, no [0, 1] -> [-1, 1] rescaling is applied here.
        with torch.no_grad():
            latent = self.model.encode(image).latents
        return latent.to(image.dtype).to(image.device)
    def decode(self, latent, grad=False):
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            out = self.model.decode(latent).sample
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        # Output is clamped to [0, 1] with no extra rescaling (cf. the KL VAEs above).
        out = torch.clamp(out, min=0.0, max=1.0)
        return out.to(latent.dtype).to(latent.device)
class No_VAE():
    # Identity "VAE" for working directly in pixel space (no latent encoding).
    scale = 1
    channels = 3
    def __init__(self, *args, **kwargs):
        pass
    def encode(self, image):
        return image
    def decode(self, image, grad=False):
        return image
# Maps short version tags to their VAE wrapper classes.
vae_vers = {
"no": No_VAE,
"v1": SDv1_VAE,
"xl": SDXL_VAE,
"v3": SDv3_VAE,
"cc": CascadeC_VAE,
"ca": CascadeA_VAE,
}
def load_vae(ver, *args, **kwargs):
    """Instantiate the VAE wrapper registered under the given version tag."""
    assert ver in vae_vers, f"Unknown VAE '{ver}'"
    vae_class = vae_vers[ver]
    return vae_class(*args, **kwargs)
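# Minimal usage sketch, not part of the module proper: it assumes a CUDA device
# at "cuda:0" and network access to download the "stabilityai/sd-vae-ft-mse"
# checkpoint used by SDv1_VAE. It round-trips a dummy [0, 1] image batch through
# the v1 wrapper; with scale = 1/8 and channels = 4, a 1x3x512x512 input maps to
# a 1x4x64x64 latent.
if __name__ == "__main__":
    vae = load_vae("v1")
    image = torch.rand(1, 3, 512, 512)   # dummy RGB batch in [0, 1]
    latent = vae.encode(image)           # -> (1, 4, 64, 64) on the VAE's device/dtype
    recon = vae.decode(latent)           # -> (1, 3, 512, 512), back in [0, 1]
    print(latent.shape, recon.shape)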