|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
from safetensors.torch import load_file
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
class Upscaler(nn.Module):
    """
    Small fully-convolutional latent upscaler.

    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py

    The network is one ``nn.Sequential`` made of three parts:
    ``head()`` (conv in, ReLU, nearest upsample by ``fac``, ReLU),
    ``core()`` (``depth`` x [conv, ReLU]) and ``tail()`` (conv out).

    Do not reorder or drop layers: pretrained weights are loaded through
    ``load_state_dict`` with keys of the form ``sequential.<index>.*``.
    Removing even a parameter-free layer (such as the second ReLU in the
    head) would shift every later index.
    """

    version = 2.1

    def head(self):
        """Return the input stage: ``chan`` -> ``size`` conv, then upsampling by ``fac``."""
        conv_in = nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad)
        return [
            conv_in,
            nn.ReLU(),
            nn.Upsample(scale_factor=self.fac, mode="nearest"),
            nn.ReLU(),
        ]

    def core(self):
        """Return ``depth`` repetitions of a ``size`` -> ``size`` conv followed by ReLU."""
        return [
            module
            for _ in range(self.depth)
            for module in (
                nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
                nn.ReLU(),
            )
        ]

    def tail(self):
        """Return the output stage: a single ``size`` -> ``chan`` conv."""
        return [nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad)]

    def __init__(self, fac, depth=16):
        """
        Build the layer stack.

        Args:
            fac: upsampling factor handed to ``nn.Upsample``. The node in this
                file passes the raw dropdown string (e.g. ``"1.5"``).
            depth: number of [conv, ReLU] pairs in the core.
        """
        super().__init__()
        self.chan = 4    # channels of the input and output tensors
        self.size = 64   # hidden feature width
        self.depth = depth
        self.fac = fac
        self.krn = 3     # conv kernel size
        self.pad = 1     # padding that keeps spatial size for krn=3

        # head() is evaluated first, so modules are created in the same order
        # as head/core/tail would be unpacked individually.
        layers = self.head() + self.core() + self.tail()
        self.sequential = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the whole head/core/tail stack."""
        return self.sequential(x)
|
|
|
|
|
|
class LatentUpscaler:
    """ComfyUI node that upscales a LATENT with a pretrained ``Upscaler`` network."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: the latent, the latent family and the scale factor."""
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_ver": (["v1", "xl"],),
                "scale_factor": (["1.25", "1.5", "2.0"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    @staticmethod
    def _scale_noise_mask(mask, factor):
        """Resize a noise mask by ``factor`` (bicubic), keeping its leading dims."""
        lead = mask.shape[:-2]
        # Bicubic interpolation only accepts 4-D (N, C, H, W) input, so fold
        # every leading dim into N; interpolation is per-channel anyway.
        flat = mask.reshape(-1, 1, *mask.shape[-2:])
        flat = torch.nn.functional.interpolate(flat, scale_factor=factor, mode="bicubic")
        # Bicubic can overshoot past the input range. Masks are assumed to be
        # in [0, 1], so clip the ringing back into that range.
        flat = flat.clamp(0.0, 1.0)
        return flat.reshape(*lead, *flat.shape[-2:])

    def upscale(self, samples, latent_ver, scale_factor):
        """
        Upscale ``samples["samples"]`` with the pretrained model for this setting.

        Args:
            samples: LATENT dict with a ``"samples"`` tensor and optionally a
                ``"noise_mask"`` tensor. Any other keys are passed through.
            latent_ver: ``"v1"`` or ``"xl"``; selects the weight file.
            scale_factor: one of the dropdown strings ``"1.25"``, ``"1.5"``,
                ``"2.0"``. It is used verbatim in the weight filename.

        Returns:
            A 1-tuple holding the new LATENT dict.
        """
        model = Upscaler(scale_factor)
        filename = f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"
        local = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename)

        # Prefer a bundled copy of the weights next to this file; otherwise fetch from the Hub.
        if os.path.isfile(local):
            print("LatentUpscaler: Using local model")
            weights = local
        else:
            print("LatentUpscaler: Using HF Hub model")
            weights = str(hf_hub_download(
                repo_id="city96/SD-Latent-Upscaler",
                filename=filename)
            )

        model.load_state_dict(load_file(weights))
        lt = samples["samples"]
        # Run wherever the latent lives, in its precision. A CPU/float32 model
        # against a CUDA or half-precision latent would fail inside Conv2d.
        model.to(device=lt.device, dtype=lt.dtype)
        model.eval()
        # Pure inference: don't build an autograd graph that holds activations.
        with torch.no_grad():
            lt = model(lt)
        del model

        # Copy so extra LATENT keys (e.g. batch_index) survive the upscale.
        out = samples.copy()
        out["samples"] = lt
        if "noise_mask" in samples:
            out["noise_mask"] = self._scale_noise_mask(samples["noise_mask"], float(scale_factor))
        return (out,)
|
|
|
|
# ComfyUI registration tables: node id -> implementing class ...
NODE_CLASS_MAPPINGS = dict(
    LatentUpscaler=LatentUpscaler,
)

# ... and node id -> human-readable name shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LatentUpscaler="Latent Upscaler",
)
|
|
|