Spaces:
Running
on
Zero
Running
on
Zero
# Author: Haohe Liu
# Email: haoheliu@gmail.com
# Date: 11 Feb 2023
import sys | |
sys.path.append("src") | |
import os | |
import wandb | |
import argparse | |
import yaml | |
import torch | |
from pytorch_lightning.strategies.ddp import DDPStrategy | |
from qa_mdt.audioldm_train.utilities.data.dataset import AudioDataset | |
from torch.utils.data import DataLoader | |
from pytorch_lightning.loggers import WandbLogger | |
from pytorch_lightning import Trainer | |
from qa_mdt.audioldm_train.modules.latent_encoder.autoencoder import AutoencoderKL | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from qa_mdt.audioldm_train.utilities.tools import get_restore_step | |
def listdir_nohidden(path):
    """Yield the entries of ``path`` whose names do not start with a dot.

    This is a generator: ``os.listdir`` is only called once iteration
    begins, and entries are yielded in the order ``os.listdir`` returns them.

    Args:
        path: Directory to list.

    Yields:
        Entry names (not full paths) that are not dot-prefixed.
    """
    yield from (
        entry for entry in os.listdir(path) if not entry.startswith(".")
    )
def main(configs, exp_group_name, exp_name):
    """Train an ``AutoencoderKL`` model from a parsed configuration.

    Builds the train/val datasets and dataloaders, the model, a checkpoint
    callback and a Weights & Biases logger, then runs ``Trainer.fit``.
    Checkpoints are written to
    ``<log_directory>/<exp_group_name>/<exp_name>/checkpoints``.

    Resume priority:
      1. a checkpoint already present in the checkpoint directory,
      2. the path given by ``configs["reload_from_ckpt"]`` (if set),
      3. otherwise training starts from scratch.

    Args:
        configs: Parsed configuration (nested dict). Keys read here:
            optional ``precision``, ``model.params.{batchsize, ddconfig,
            lossconfig, embed_dim, image_key, subband}``,
            ``model.base_learning_rate``,
            ``preprocessing.audio.sampling_rate``, ``log_directory``,
            ``data`` (optional ``dataloader_add_ons``),
            ``step.save_top_k``, ``project`` and optional
            ``reload_from_ckpt``.
        exp_group_name: Experiment group; used as a path component and in
            the logger run name.
        exp_name: Experiment name; used as a path component and in the
            logger run name.
    """
    # All settings are read from the ``configs`` argument rather than a
    # module-level global, so this function also works when imported and
    # called from another module.
    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    model_params = configs["model"]["params"]
    batch_size = model_params["batchsize"]
    log_path = configs["log_directory"]
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    dataset = AudioDataset(configs, split="train", add_ons=dataloader_add_ons)
    loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True
    )

    print(
        "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )

    val_dataset = AudioDataset(configs, split="val", add_ons=dataloader_add_ons)
    # NOTE(review): the validation loader is shuffled; combined with
    # limit_val_batches below this means each validation run may see a
    # different subset. Kept as in the original - confirm it is intended.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=8,
        shuffle=True,
    )

    model = AutoencoderKL(
        ddconfig=model_params["ddconfig"],
        lossconfig=model_params["lossconfig"],
        embed_dim=model_params["embed_dim"],
        image_key=model_params["image_key"],
        base_learning_rate=configs["model"]["base_learning_rate"],
        subband=model_params["subband"],
        sampling_rate=configs["preprocessing"]["audio"]["sampling_rate"],
    )

    # Optional checkpoint path from the config; None when the key is absent.
    config_reload_from_ckpt = configs.get("reload_from_ckpt")

    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-{global_step:.0f}",
        every_n_train_steps=5000,
        save_top_k=configs["step"]["save_top_k"],
        auto_insert_metric_name=False,
        save_last=True,
    )

    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    model.set_log_dir(log_path, exp_group_name, exp_name)

    os.makedirs(checkpoint_path, exist_ok=True)

    if os.listdir(checkpoint_path):
        # An existing checkpoint in the experiment directory takes priority
        # over the one named in the config.
        print("Load checkpoint from path: %s" % checkpoint_path)
        # The first returned value is joined onto the checkpoint directory,
        # i.e. it is used as a checkpoint file name; the second is unused.
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        print("Train from scratch")
        resume_from_checkpoint = None

    # Use every CUDA device visible to this process.
    devices = torch.cuda.device_count()

    wandb_logger = WandbLogger(
        save_dir=wandb_path,
        project=configs["project"],
        config=configs,
        name="%s/%s" % (exp_group_name, exp_name),
    )

    trainer = Trainer(
        accelerator="gpu",
        devices=devices,
        logger=wandb_logger,
        limit_val_batches=100,
        callbacks=[checkpoint_callback],
        strategy=DDPStrategy(find_unused_parameters=True),
        val_check_interval=2000,
    )

    # TRAINING
    trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint)

    # EVALUATION
    # trainer.test(model, test_loader, ckpt_path=resume_from_checkpoint)
if __name__ == "__main__":
    # Command-line entry point: parse the config path, derive experiment
    # names from it, load the YAML config and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--autoencoder_config",
        type=str,
        required=True,
        help="path to autoencoder config .yaml",
    )
    args = parser.parse_args()

    config_path = args.autoencoder_config

    # Experiment name = config file name without its extension. Take the
    # basename first and strip only the final extension, so that dots
    # elsewhere in the path (e.g. "./configs/x.yaml") do not yield an
    # empty name.
    exp_name = os.path.splitext(os.path.basename(config_path))[0]
    # Experiment group = name of the directory containing the config file.
    exp_group_name = os.path.basename(os.path.dirname(config_path))

    # Read the config inside a context manager so the file handle is closed.
    # The parsed dict is bound to the module-level name ``config_yaml``.
    with open(config_path, "r") as config_file:
        config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

    main(config_yaml, exp_group_name, exp_name)