test / extensions /SD-Latent-Interposer /comfy_latent_interposer.py
dikdimon's picture
Upload extensions using SD-Hub extension
c336648 verified
raw
history blame
No virus
5.41 kB
import os
import torch
import torch.nn as nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# v1 = Stable Diffusion 1.x
# xl = Stable Diffusion Extra Large (SDXL)
# v3 = Stable Diffusion Version Three (SD3)
# cc = Stable Cascade (Stage C) [not used]
# ca = Stable Cascade (Stage A/B)
# Per-conversion hyper-parameters, keyed as "<source>-to-<destination>".
# Each entry is expanded as keyword arguments into InterposerModel:
#   ch_in / ch_out -> channel counts of the input / output latents
#   ch_mid         -> channel width used inside the network
#   scale          -> factor handed to the resampling layer
#   blocks         -> number of residual blocks in the core
config = {
    # from v1
    "v1-to-xl": dict(ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12),
    "v1-to-v3": dict(ch_in=4, ch_out=16, ch_mid=64, scale=1.0, blocks=12),
    # from xl
    "xl-to-v1": dict(ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12),
    "xl-to-v3": dict(ch_in=4, ch_out=16, ch_mid=64, scale=1.0, blocks=12),
    # from v3 (16-channel input)
    "v3-to-v1": dict(ch_in=16, ch_out=4, ch_mid=64, scale=1.0, blocks=12),
    "v3-to-xl": dict(ch_in=16, ch_out=4, ch_mid=64, scale=1.0, blocks=12),
    # from ca (the only entries with a scale other than 1.0)
    "ca-to-v1": dict(ch_in=4, ch_out=4, ch_mid=64, scale=0.5, blocks=12),
    "ca-to-xl": dict(ch_in=4, ch_out=4, ch_mid=64, scale=0.5, blocks=12),
    "ca-to-v3": dict(ch_in=4, ch_out=16, ch_mid=64, scale=0.5, blocks=12),
}
class ResBlock(nn.Module):
    """Residual block operating at a fixed channel count.

    The input is batch-normalised first; a three-convolution branch is then
    computed from the normalised tensor, added back onto that same normalised
    tensor, and the sum is passed through ReLU.

    Args:
        ch: number of channels, identical for input and output.
    """

    def __init__(self, ch):
        super().__init__()
        # 3x3 convolutions with stride 1 / padding 1 keep the spatial size.
        conv_args = dict(kernel_size=3, stride=1, padding=1)

        self.join = nn.ReLU()
        self.norm = nn.BatchNorm2d(ch)
        # Attribute names and positions inside the Sequential determine the
        # state-dict keys, so they must stay exactly as they are.
        self.long = nn.Sequential(
            nn.Conv2d(ch, ch, **conv_args),
            nn.SiLU(),
            nn.Conv2d(ch, ch, **conv_args),
            nn.SiLU(),
            nn.Conv2d(ch, ch, **conv_args),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Return ReLU(long(norm(x)) + norm(x))."""
        normed = self.norm(x)
        branch = self.long(normed)
        # The skip connection carries the normalised input, not the raw one.
        return self.join(branch + normed)
class ExtractBlock(nn.Module):
    """Project a tensor from ``ch_in`` to ``ch_out`` channels.

    Two parallel paths are computed from the same input -- a single
    convolution (``short``) and a three-convolution stack (``long``) -- then
    summed and passed through ReLU.

    Args:
        ch_in: channel count of the incoming tensor.
        ch_out: channel count of the produced tensor.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # 3x3 convolutions with stride 1 / padding 1 keep the spatial size.
        conv_args = dict(kernel_size=3, stride=1, padding=1)

        self.join = nn.ReLU()
        # Only the first convolution on each path changes the channel count.
        self.short = nn.Conv2d(ch_in, ch_out, **conv_args)
        self.long = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, **conv_args),
            nn.SiLU(),
            nn.Conv2d(ch_out, ch_out, **conv_args),
            nn.SiLU(),
            nn.Conv2d(ch_out, ch_out, **conv_args),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Return ReLU(long(x) + short(x))."""
        deep = self.long(x)
        shallow = self.short(x)
        return self.join(deep + shallow)
class InterposerModel(nn.Module):
    """
    Latent-to-latent conversion network.

    NN layout, ported from:
    https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py

    Pipeline: ``head`` (ExtractBlock, ch_in -> ch_mid), ``core`` (nearest
    resampling by ``scale``, ``blocks`` ResBlocks, BatchNorm2d, SiLU), then
    ``tail`` (3x3 convolution, ch_mid -> ch_out).

    Args:
        ch_in: channels of the input latent.
        ch_out: channels of the output latent.
        ch_mid: channel width used by head, core and tail input.
        scale: scale factor passed to ``nn.Upsample``.
        blocks: how many ResBlocks to stack in the core.
    """

    def __init__(self, ch_in=4, ch_out=4, ch_mid=64, scale=1.0, blocks=12):
        super().__init__()
        # Keep the hyper-parameters available on the instance.
        self.ch_in, self.ch_out, self.ch_mid = ch_in, ch_out, ch_mid
        self.blocks = blocks
        self.scale = scale

        self.head = ExtractBlock(ch_in, ch_mid)

        # Layer positions inside the Sequential determine the state-dict keys
        # (core.0 = resample, core.1..N = ResBlocks, then BatchNorm and SiLU).
        core_layers = [nn.Upsample(scale_factor=scale, mode="nearest")]
        core_layers.extend(ResBlock(ch_mid) for _ in range(blocks))
        core_layers.append(nn.BatchNorm2d(ch_mid))
        core_layers.append(nn.SiLU())
        self.core = nn.Sequential(*core_layers)

        self.tail = nn.Conv2d(ch_mid, ch_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run ``x`` through head, core and tail in that order."""
        return self.tail(self.core(self.head(x)))
class ComfyLatentInterposer:
    """Custom node that converts a latent from one model family to another.

    The most recently used conversion network is kept on the instance so
    repeated runs with the same source/destination pair skip reloading.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "convert"
    CATEGORY = "latent"
    TITLE = "Latent Interposer"

    def __init__(self):
        self.version = 4.0  # network revision
        self.loaded = None  # current model name
        self.model = None  # current model

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the latent plus source/destination choices."""
        required = {
            "samples": ("LATENT", ),
            "latent_src": (["v1", "xl", "v3", "ca"],),
            "latent_dst": (["v1", "xl", "v3"],),
        }
        return {"required": required}

    def get_model_path(self, model_name):
        """Locate the weights file for ``model_name`` and return its path.

        Looks next to this file in ``models/<file>`` first, then in
        ``models/v<version>/<file>``; if neither exists the file is requested
        from the Hugging Face Hub.

        Args:
            model_name: conversion key such as ``"xl-to-v1"``.

        Returns:
            Path to the ``.safetensors`` file as a string.
        """
        version_dir = f"v{self.version}"
        fname = f"{model_name}_interposer-v{self.version}.safetensors"
        models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")

        # Probe order matters: the flat layout wins over the versioned one.
        candidates = (
            os.path.join(models_dir, fname),
            os.path.join(models_dir, version_dir, fname),
        )
        for candidate in candidates:
            if os.path.isfile(candidate):
                print("LatentInterposer: Using local model")
                return candidate

        # huggingface hub fallback
        print("LatentInterposer: Using HF Hub model")
        hub_path = hf_hub_download(
            repo_id="city96/SD-Latent-Interposer",
            subfolder=version_dir,
            filename=fname,
        )
        return str(hub_path)

    def convert(self, samples, latent_src, latent_dst):
        """Convert ``samples["samples"]`` from ``latent_src`` to ``latent_dst``.

        Args:
            samples: mapping holding the latent tensor under ``"samples"``;
                it is shallow-copied, so the caller's mapping is not modified.
            latent_src: source latent family key.
            latent_dst: destination latent family key.

        Returns:
            One-element tuple containing the copied mapping with the converted
            tensor (or the unchanged copy when source equals destination).

        Raises:
            ValueError: if no configuration exists for the requested pair.
        """
        samples = samples.copy()

        # Nothing to convert.
        if latent_src == latent_dst:
            return (samples,)

        model_name = f"{latent_src}-to-{latent_dst}"
        if model_name not in config:
            raise ValueError(f"No model exists for this conversion! ({model_name})")

        # Reuse the cached network unless a different conversion is requested.
        cache_hit = self.model is not None and self.loaded == model_name
        if not cache_hit:
            weights_path = self.get_model_path(model_name)
            network = InterposerModel(**config[model_name])
            network.eval()
            network.load_state_dict(load_file(weights_path))
            # Only remember the model once it has loaded successfully.
            self.model = network
            self.loaded = model_name

        latent = samples["samples"]
        with torch.no_grad():
            # Inference always runs in FP32 on the CPU; the result is then
            # moved back to the original device and cast to the original dtype.
            cpu_input = latent.cpu().float()
            converted = self.model(cpu_input)
            latent = converted.to(latent.device).to(latent.dtype)
        samples["samples"] = latent
        return (samples,)
# Module-level lookup tables: node identifier -> implementing class, and
# node identifier -> human-readable title.
NODE_CLASS_MAPPINGS = {"LatentInterposer": ComfyLatentInterposer}
NODE_DISPLAY_NAME_MAPPINGS = {"LatentInterposer": ComfyLatentInterposer.TITLE}