# LN3Diff / scripts /vit_triplane_cldm_train.py
# NIRVANALAN
# release file
# 87c126b
# raw
# history blame
# 12.7 kB
"""
Train a diffusion model on images.
"""
import json
import sys
import os
sys.path.append('.')
# from dnnlib import EasyDict
import traceback
import torch as th
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
import argparse
import dnnlib
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
args_to_dict,
add_dict_to_argparser,
continuous_diffusion_defaults,
model_and_diffusion_defaults,
create_model_and_diffusion,
)
from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
import nsr
import nsr.lsgm
# from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default
from datasets.shapenet import load_data, load_eval_data, load_memory_data
from nsr.losses.builder import E3DGELossClass
from utils.torch_utils import legacy, misc
from torch.utils.data import Subset
from pdb import set_trace as st
from dnnlib.util import EasyDict, InfiniteSampler
# from .vit_triplane_train_FFHQ import init_dataset_kwargs
from datasets.eg3d_dataset import init_dataset_kwargs
# from torch.utils.tensorboard import SummaryWriter
# Single RNG seed used by training_loop for the CUDA and NumPy generators,
# the dataset kwargs (random_seed) and the InfiniteSampler.
SEED = 0
def training_loop(args):
    """Set up distributed state, models and data, then run diffusion training.

    Steps, in order: select the CUDA device and seed the RNGs, initialise the
    distributed helpers and the logger, build the denoising model and its
    diffusion, build the 3D auto-encoder (switched to eval mode, triplane
    decoder frozen), create the train / eval data pipelines, dump the
    arguments to ``args.json`` on rank 0, and finally hand everything to the
    trainer class selected by ``args.trainer_name``.

    Args:
        args: Argument namespace (see ``create_argparser``) that additionally
            carries ``local_rank``. It is mutated in place: ``img_size`` and
            ``schedule_sampler`` are always overwritten, and depending on the
            configuration so are ``diffusion_input_size``,
            ``denoise_in_channels``, ``denoise_out_channels`` and
            ``sr_kwargs``. The whole namespace is forwarded to the trainer as
            keyword arguments.

    Raises:
        KeyError: if ``args.trainer_name`` is not a registered trainer name.
        AssertionError: if ``denoise_in_channels`` is set explicitly while
            ``denoise_out_channels`` is left at -1, or if a ``vpsde`` trainer
            is requested without ``mixed_prediction``.
    """
    logger.log("dist setup...")

    th.cuda.set_device(
        args.local_rank)  # set this line to avoid extra memory on rank 0
    th.cuda.empty_cache()

    th.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    dist_util.setup_dist(args)
    logger.configure(dir=args.logdir)

    logger.log("creating ViT encoder and NSR decoder...")
    device = dist_util.dev()

    args.img_size = [args.image_size_encoder]

    logger.log("creating model and diffusion...")
    # * set denoise model args
    if args.denoise_in_channels == -1:
        # -1 means "derive from the auto-encoder settings".
        args.diffusion_input_size = args.image_size_encoder
        args.denoise_in_channels = args.out_chans
        args.denoise_out_channels = args.out_chans
    else:
        assert args.denoise_out_channels != -1

    denoise_model, diffusion = create_model_and_diffusion(
        **args_to_dict(args,
                       model_and_diffusion_defaults().keys()))
    denoise_model.to(device)
    denoise_model.train()

    opts = eg3d_options_default()
    if args.sr_training:
        args.sr_kwargs = dnnlib.EasyDict(
            channel_base=opts.cbase,
            channel_max=opts.cmax,
            fused_modconv_default='inference_only',
            use_noise=True
        )  # ! close noise injection? since noise_mode='none' in eg3d

    logger.log("creating encoder and NSR decoder...")
    auto_encoder = create_3DAE_model(
        **args_to_dict(args,
                       encoder_and_nsr_defaults().keys()))
    auto_encoder.to(device)
    auto_encoder.eval()

    # set grad=False in this manner suppresses the DDP forward no grad error.
    logger.log("freeze triplane decoder...")
    for param in auto_encoder.decoder.triplane_decoder.parameters(
    ):  # type: ignore
        param.requires_grad_(False)

    if args.cfg in ('afhq', 'ffhq'):
        if args.sr_training:
            logger.log("AE triplane decoder reuses G_ema SR module...")
            # NOTE(review): `G_ema` is neither defined nor imported in this
            # module, so this statement raises NameError when reached.
            # Presumably it should be loaded from an EG3D checkpoint
            # (e.g. args.resume_checkpoint_EG3D) first -- TODO confirm.
            auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict(  # type: ignore
                G_ema.superresolution.state_dict())  # type: ignore

            # set grad=False in this manner suppresses the DDP forward no grad error.
            for param in auto_encoder.decoder.triplane_decoder.superresolution.parameters(
            ):  # type: ignore
                param.requires_grad_(False)

        # ! load data
        logger.log("creating eg3d data loader...")
        training_set_kwargs, _ = init_dataset_kwargs(
            data=args.data_dir,
            class_name='datasets.eg3d_dataset.ImageFolderDataset'
        )  # only load pose here

        training_set_kwargs.use_labels = True
        training_set_kwargs.xflip = True
        training_set_kwargs.random_seed = SEED

        # * construct ffhq/afhq dataset (built once; the result is shared by
        # the training loader and the evaluation subset below)
        training_set = dnnlib.util.construct_class_by_name(
            **training_set_kwargs)  # subclass of training.dataset.Dataset

        training_set_sampler = InfiniteSampler(
            dataset=training_set,
            rank=dist_util.get_rank(),
            num_replicas=dist_util.get_world_size(),
            seed=SEED)

        data = iter(
            th.utils.data.DataLoader(
                dataset=training_set,
                sampler=training_set_sampler,
                batch_size=args.batch_size,
                pin_memory=True,
                num_workers=args.num_workers,
            ))

        # Evaluation reuses the first 10 items of the training set.
        eval_data = th.utils.data.DataLoader(dataset=Subset(
            training_set, np.arange(10)),
                                             batch_size=args.eval_batch_size,
                                             num_workers=1)
    else:
        logger.log("creating data loader...")
        if args.overfitting:
            logger.log("create overfitting memory dataset")
            data = load_memory_data(
                file_path=args.eval_data_dir,
                batch_size=args.batch_size,
                reso=args.image_size,
                reso_encoder=args.image_size_encoder,  # 224 -> 128
                num_workers=args.num_workers,
                load_depth=True  # for evaluation
            )
        else:
            logger.log("create all instances dataset")
            data = load_data(
                file_path=args.data_dir,
                batch_size=args.batch_size,
                reso=args.image_size,
                reso_encoder=args.image_size_encoder,  # 224 -> 128
                num_workers=args.num_workers,
                load_depth=True,
                preprocess=auto_encoder.preprocess,  # clip
                dataset_size=args.dataset_size,
            )

        eval_data = load_eval_data(
            file_path=args.eval_data_dir,
            batch_size=args.eval_batch_size,
            reso=args.image_size,
            reso_encoder=args.image_size_encoder,  # 224 -> 128
            num_workers=args.num_workers,
            load_depth=True  # for evaluation
        )

    # Rank 0 records the run configuration. This happens before
    # args.schedule_sampler is replaced by a sampler object below.
    if dist_util.get_rank() == 0:
        with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)

    args.schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, diffusion)

    opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
    loss_class = E3DGELossClass(device, opt).to(device)

    logger.log("training...")

    # Trainer registry keyed by --trainer_name; unknown names raise KeyError.
    TrainLoop = {
        'adm': nsr.TrainLoop3DDiffusion,
        'dit': nsr.TrainLoop3DDiffusionDiT,
        'ssd': nsr.TrainLoop3DDiffusionSingleStage,
        'ssd_cvD_sds': nsr.TrainLoop3DDiffusionSingleStagecvDSDS,
        'ssd_cvd_sds_no_separate_sds_step':
        nsr.TrainLoop3DDiffusionSingleStagecvDSDS_sdswithrec,
        'vpsde_lsgm_noD': nsr.lsgm.TrainLoop3DDiffusionLSGM_noD,  # use vpsde
    }[args.trainer_name]

    if 'vpsde' in args.trainer_name:
        sde_diffusion = make_sde_diffusion(
            dnnlib.EasyDict(
                args_to_dict(args,
                             continuous_diffusion_defaults().keys())))
        assert args.mixed_prediction, 'enable mixed_prediction by default'
        logger.log('create VPSDE diffusion.')
    else:
        sde_diffusion = None

    # let all processes sync up before training starts
    dist_util.synchronize()

    TrainLoop(rec_model=auto_encoder,
              denoise_model=denoise_model,
              diffusion=diffusion,
              sde_diffusion=sde_diffusion,
              loss_class=loss_class,
              data=data,
              eval_data=eval_data,
              **vars(args)).run_loop()
def create_argparser(**kwargs):
    """Build the command-line parser for this training script.

    The script-level defaults declared here are overlaid, in order, with the
    dictionaries returned by ``model_and_diffusion_defaults``,
    ``continuous_diffusion_defaults``, ``encoder_and_nsr_defaults`` and
    ``loss_defaults``; a key supplied by a later source replaces an earlier
    one. The merged dictionary is then passed to ``add_dict_to_argparser``.

    Args:
        **kwargs: Accepted but not used.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    defaults = {
        "dataset_size": -1,
        "diffusion_input_size": -1,
        "trainer_name": 'adm',
        "use_amp": False,
        "triplane_scaling_divider": 1.0,  # divide by this value
        "overfitting": False,
        "num_workers": 4,
        "image_size": 128,
        "image_size_encoder": 224,
        "iterations": 150000,
        "schedule_sampler": "uniform",
        "anneal_lr": False,
        "lr": 5e-5,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 1,
        "eval_batch_size": 12,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 50,
        "eval_interval": 2500,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "resume_checkpoint_EG3D": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        "data_dir": "",
        "eval_data_dir": "",
        "logdir": "/mnt/lustre/yslan/logs/nips23/",
        "load_submodule_name": '',  # for loading pretrained auto_encoder model
        "ignore_resume_opt": False,
        "denoised_ae": True,
    }

    # Overlay the library-provided defaults; the order decides which source
    # wins when the same key appears more than once.
    default_sources = (
        model_and_diffusion_defaults,
        continuous_diffusion_defaults,
        encoder_and_nsr_defaults,
        loss_defaults,
    )
    for source in default_sources:
        defaults.update(source())  # type: ignore

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, defaults)
    return arg_parser
if __name__ == "__main__":
    # Script entry point. LOCAL_RANK must be present in the environment
    # (a missing variable raises KeyError below).
    # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    # os.environ["NCCL_DEBUG"] = "INFO"
    os.environ[
        "TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # set to DETAIL for runtime logging.

    args = create_argparser().parse_args()
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.gpus = th.cuda.device_count()

    args.rendering_kwargs = rendering_options_defaults(args)

    # Launch processes.
    logger.log('Launching processes...')
    logger.log('Available devices ', th.cuda.device_count())
    logger.log('Current cuda device ', th.cuda.current_device())

    try:
        training_loop(args)
    except Exception:
        # Report the failure; the exception is deliberately not re-raised,
        # matching the previous behaviour of this script.
        traceback.print_exc()
    finally:
        # Clean port and socket on every exit path. KeyboardInterrupt is not a
        # subclass of Exception, so only `finally` covers Ctrl+C as well.
        # NOTE(review): assumes dist_util.cleanup() is safe to call after a
        # successful run -- confirm against its implementation.
        dist_util.cleanup()