"""Train a VAE model with PyTorch Lightning and save the final checkpoint.

The model architecture, its hyper-parameters, the trainer settings and the
TensorBoard logger settings are all read from ``config.config``.
"""
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from config import config
from models import vae_models

# NOTE(review): presumably a workaround for the Intel OpenMP "OMP: Error #15"
# abort raised when two OpenMP runtimes are loaded in one process. It tolerates
# the duplicate runtime rather than removing it -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Directory (relative to the current working directory) for saved checkpoints.
SAVE_DIR = "saved_models"


def make_model(config):
    """Instantiate the model selected by ``config.model_type``.

    Args:
        config: configuration object exposing ``model_type`` (a key of
            ``vae_models``) and ``model_config`` (an object whose ``.dict()``
            items are passed as keyword arguments to the model constructor).

    Returns:
        A new instance of ``vae_models[config.model_type]``.

    Raises:
        NotImplementedError: if ``config.model_type`` is not in ``vae_models``.
    """
    model_type = config.model_type
    if model_type not in vae_models:
        raise NotImplementedError(
            f"Model Architecture not implemented: {model_type!r}"
        )
    return vae_models[model_type](**config.model_config.dict())


def main():
    """Build the model, optionally tune its learning rate, train, and save it."""
    model = make_model(config)
    train_config = config.train_config
    logger = TensorBoardLogger(**config.log_config.dict())
    trainer = Trainer(
        **train_config.dict(),
        logger=logger,
        callbacks=[LearningRateMonitor()],
    )

    if train_config.auto_lr_find:
        # trainer.tuner.lr_find is the PyTorch Lightning 1.x tuning API.
        lr_finder = trainer.tuner.lr_find(model)
        new_lr = lr_finder.suggestion()
        if new_lr is None:
            # suggestion() yields None when it cannot compute a suggestion
            # (e.g. too few points); never overwrite the lr with None.
            print("Learning rate finder gave no suggestion; "
                  "keeping the configured learning rate")
        else:
            print("Learning Rate Chosen:", new_lr)
            # Assumes the model reads self.lr when building its optimizer -- TODO confirm.
            model.lr = new_lr

    trainer.fit(model)

    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs(SAVE_DIR, exist_ok=True)
    ckpt_name = (
        f"{config.model_type}_alpha_{config.model_config.alpha}"
        f"_dim_{config.model_config.hidden_size}.ckpt"
    )
    trainer.save_checkpoint(os.path.join(SAVE_DIR, ckpt_name))


if __name__ == "__main__":
    main()