|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import argparse
|
|
import random
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from safetensors.torch import save_file, load_file
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from upscaler import LatentUpscaler as Upscaler
|
|
from vae import get_vae
|
|
|
|
# Let cuDNN auto-tune conv algorithms; worthwhile here since latent shapes
# are fixed per run (assumes constant input sizes — TODO confirm).
torch.backends.cudnn.benchmark = True
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for latent interposer training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before; passing a
            list makes this function testable without patching sys.argv.

    Returns:
        argparse.Namespace. Note ``lr`` stays a *string* (validated to parse
        as float) because downstream code converts it with ``float(args.lr)``.

    Raises:
        SystemExit: via ``parser.error`` when ``--lr`` is not a valid float.
    """
    parser = argparse.ArgumentParser(description="Train latent interposer model")
    parser.add_argument("--steps", type=int, default=500000, help="No. of training steps")
    parser.add_argument('--bs', type=int, default=4, help="Batch size")
    parser.add_argument('--lr', default="5e-4", help="Learning rate")
    parser.add_argument("-n", "--save_every_n", type=int, dest="save", default=50000, help="Save model/sample periodically")
    parser.add_argument("-r", "--res", type=int, default=512, help="Source resolution")
    parser.add_argument("-f", "--fac", type=float, default=1.5, help="Upscale factor")
    parser.add_argument("-v", "--ver", choices=["v1","xl"], default="v1", help="SD version")
    parser.add_argument('--vae', help="Path to VAE (Optional)")
    parser.add_argument('--resume', help="Checkpoint to resume from")
    args = parser.parse_args(argv)
    # Validate early so a typo'd learning rate fails at startup, not mid-run.
    # Catch only the conversion errors — a bare except would also swallow
    # KeyboardInterrupt and hide real bugs.
    try:
        float(args.lr)
    except (TypeError, ValueError):
        parser.error("--lr must be a valid float eg. 0.001 or 1e-3")
    return args
|
|
|
|
# Lazily-initialized VAE, shared across sample_decode() calls so the decoder
# is only loaded onto the GPU once.
vae = None

def sample_decode(latent, filename, version):
    """Decode a single latent to an RGB image and save it to *filename*.

    Args:
        latent: latent tensor; assumed shape (1, C, H, W) — the batch dim
            is squeezed out below. TODO confirm against callers.
        filename: output image path (format inferred by PIL from extension).
        version: SD version string ("v1"/"xl") forwarded to get_vae().
    """
    global vae
    if not vae:
        # First call: load the half-precision VAE and keep it resident on GPU.
        vae = get_vae(version, fp16=True)
        vae.to("cuda")

    # Match the VAE's fp16 weights/device before decoding.
    latent = latent.half().to("cuda")
    out = vae.decode(latent).sample
    out = out.cpu().detach().numpy()
    out = np.squeeze(out, 0)
    # CHW -> HWC, as PIL expects.
    out = out.transpose((1, 2, 0))
    # Decoder output is nominally in [-1, 1]; clamp, then rescale to 8-bit.
    out = np.clip(out, -1.0, 1.0)
    out = (out+1)/2 * 255
    out = out.astype(np.uint8)
    out = Image.fromarray(out)
    out.save(filename)
|
|
|
def eval_model(step, model, criterion, scheduler, src, dst):
    """Compute test loss on the held-out pair and log progress.

    NOTE(review): this function reads two module-level globals set by the
    __main__ block — ``loss`` (the most recent *training* loss from the
    train loop) and ``log`` (the open CSV file). It only works when called
    from the training script; consider passing both in explicitly.

    Args:
        step: current training step, written to console and CSV.
        model: the upscaler network.
        criterion: loss function used for the test pair.
        scheduler: LR scheduler, queried for the current learning rate.
        src: held-out source latent.
        dst: held-out target latent.
    """
    with torch.no_grad():
        t_pred = model(src)
        t_loss = criterion(t_pred, dst)
    # Console: step, train|test loss, current LR. CSV columns mirror this.
    tqdm.write(f"{str(step):<10} {loss.data.item():.4e}|{t_loss.data.item():.4e} @ {float(scheduler.get_last_lr()[0]):.4e}")
    log.write(f"{step},{loss.data.item()},{t_loss.data.item()},{float(scheduler.get_last_lr()[0])}\n")
    log.flush()
|
|
|
|
def save_model(step, model, ver, fac, src):
    """Save a decoded preview image and the model weights to ./models/.

    Args:
        step: current training step; embedded (in thousands) in the filename.
        model: the upscaler network to snapshot.
        ver: SD version string ("v1"/"xl"), used in the filename.
        fac: upscale factor, used in the filename.
        src: sample source latent used to render the preview image.
    """
    # Preview is inference-only: no_grad avoids building an autograd graph
    # (and holding activations) for a pass we never backprop through.
    with torch.no_grad():
        out = model(src)
    output_name = f"./models/latent-upscaler_SD{ver}-x{fac}_e{round(step/1000)}k"
    sample_decode(out, f"{output_name}.png", ver)
    save_file(model.state_dict(), f"{output_name}.safetensors")
|
|
|
|
class Latent:
    """One (source, destination) latent training pair, resident on the GPU."""

    def __init__(self, md5, ver, src_res, dst_res):
        # Load source first, then destination — both keyed by the same md5.
        self.src = self._load(md5, ver, src_res)
        self.dst = self._load(md5, ver, dst_res)

    @staticmethod
    def _load(md5, ver, res):
        # Read the cached .npy latent, move it to the GPU, drop the batch dim.
        path = os.path.join(f"latents/{ver}_{res}px", f"{md5}.npy")
        tensor = torch.from_numpy(np.load(path)).to("cuda")
        return torch.squeeze(tensor, 0)
|
|
|
|
class LatentDataset(Dataset):
    """Dataset of precomputed latent pairs, fully loaded up front.

    Enumerates the source-resolution latent directory and pairs each entry
    (by md5 filename stem) with its destination-resolution counterpart.
    """

    def __init__(self, ver, src_res, dst_res):
        print("Loading latents from disk")
        src_dir = f"latents/{ver}_{src_res}px"
        self.latents = [
            Latent(os.path.splitext(name)[0], ver, src_res, dst_res)
            for name in tqdm(os.listdir(src_dir))
        ]

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, index):
        pair = self.latents[index]
        return (pair.src, pair.dst)
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    target_dev = "cuda"
    dst_res = int(args.res*args.fac)

    dataset = LatentDataset(args.ver, args.res, dst_res)
    loader = DataLoader(
        dataset,
        batch_size=args.bs,
        shuffle=True,
        num_workers=0,  # latents already live on the GPU; workers can't help
    )

    os.makedirs("models", exist_ok=True)
    log = open(f"models/latent-upscaler_SD{args.ver}-x{args.fac}.csv", "w")

    # Prefer a fixed held-out test pair when one exists on disk; otherwise
    # fall back to the first training sample as an (imperfect) progress probe.
    if os.path.isfile(f"test_{args.ver}_{args.res}px.npy") and os.path.isfile(f"test_{args.ver}_{dst_res}px.npy"):
        eval_src = torch.from_numpy(np.load(f"test_{args.ver}_{args.res}px.npy")).to(target_dev)
        eval_dst = torch.from_numpy(np.load(f"test_{args.ver}_{dst_res}px.npy")).to(target_dev)
    else:
        eval_src = torch.unsqueeze(dataset[0][0],0)
        eval_dst = torch.unsqueeze(dataset[0][1],0)

    model = Upscaler(args.fac)
    if args.resume:
        model.load_state_dict(load_file(args.resume))
    model.to(target_dev)

    criterion = torch.nn.L1Loss()

    # LR is scaled down by batch size; scheduler steps once per batch, hence
    # total_steps = steps / bs.
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(args.lr)/args.bs)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        total_steps=int(args.steps/args.bs),
        max_lr=float(args.lr)/args.bs,
        pct_start=0.015,
        final_div_factor=2500,
    )
    # FIX: the forward pass runs under fp16 autocast, so the loss must be
    # scaled before backward() — otherwise small gradients underflow to zero
    # in half precision. GradScaler handles scale/unscale and skips steps
    # whose gradients contain inf/nan.
    scaler = torch.cuda.amp.GradScaler()

    progress = tqdm(total=args.steps)

    while progress.n < args.steps:
        for src, dst in loader:
            with torch.cuda.amp.autocast():
                y_pred = model(src)
                # NOTE: keep the name `loss` — eval_model() reads it as a
                # module-level global.
                loss = criterion(y_pred, dst)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            progress.update(args.bs)
            # progress.n only lands on multiples of bs; the `x % bs` term
            # nudges the trigger interval onto such a multiple.
            if progress.n % (1000 + 1000%args.bs) == 0:
                eval_model(progress.n, model, criterion, scheduler, eval_src, eval_dst)
            if progress.n % (args.save + args.save%args.bs) == 0:
                save_model(progress.n, model, args.ver, args.fac, eval_src)
            if progress.n >= args.steps:
                break
    progress.close()

    # Final evaluation + checkpoint at the nominal step count.
    eval_model(args.steps, model, criterion, scheduler, eval_src, eval_dst)
    save_model(args.steps, model, args.ver, args.fac, eval_src)
    log.close()
|
|