test / extensions /SD-Latent-Upscaler /comfy_latent_upscaler.py
dikdimon's picture
Upload extensions using SD-Hub extension
c336648 verified
raw
history blame
2.72 kB
import os
import torch
import torch.nn as nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
class Upscaler(nn.Module):
"""
Basic NN layout, ported from:
https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py
"""
version = 2.1 # network revision
def head(self):
return [
nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
nn.ReLU(),
nn.Upsample(scale_factor=self.fac, mode="nearest"),
nn.ReLU(),
]
def core(self):
layers = []
for _ in range(self.depth):
layers += [
nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
nn.ReLU(),
]
return layers
def tail(self):
return [
nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
]
def __init__(self, fac, depth=16):
super().__init__()
self.size = 64 # Conv2d size
self.chan = 4 # in/out channels
self.depth = depth # no. of layers
self.fac = fac # scale factor
self.krn = 3 # kernel size
self.pad = 1 # padding
self.sequential = nn.Sequential(
*self.head(),
*self.core(),
*self.tail(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.sequential(x)
class LatentUpscaler:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT", ),
"latent_ver": (["v1", "xl"],),
"scale_factor": (["1.25", "1.5", "2.0"],),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale"
CATEGORY = "latent"
def upscale(self, samples, latent_ver, scale_factor):
model = Upscaler(scale_factor)
filename = f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"
local = os.path.join(
os.path.join(os.path.dirname(os.path.realpath(__file__)),"models"),
filename
)
if os.path.isfile(local):
print("LatentUpscaler: Using local model")
weights = local
else:
print("LatentUpscaler: Using HF Hub model")
weights = str(hf_hub_download(
repo_id="city96/SD-Latent-Upscaler",
filename=filename)
)
model.load_state_dict(load_file(weights))
lt = samples["samples"]
lt = model(lt)
del model
if "noise_mask" in samples.keys():
# expand the noise mask to the same shape as the latent
mask = torch.nn.functional.interpolate(samples['noise_mask'], scale_factor=float(scale_factor), mode='bicubic')
return ({"samples": lt, "noise_mask": mask},)
return ({"samples": lt},)
NODE_CLASS_MAPPINGS = {
"LatentUpscaler": LatentUpscaler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LatentUpscaler": "Latent Upscaler"
}