import os

import torch

from tools.utils.logging import get_logger


def save_ckpt(
    model,
    cfg,
    optimizer,
    lr_scheduler,
    epoch,
    global_step,
    metrics,
    is_best=False,
    logger=None,
    prefix=None,
):
    """
    Save a training checkpoint.

    :param model: model to save (may be wrapped in DistributedDataParallel)
    :param cfg: config dict; the output directory is read from cfg["Global"]["output_dir"]
    :param optimizer: optimizer whose state is saved (skipped when is_best is True)
    :param lr_scheduler: LR scheduler whose state is saved (skipped when is_best is True)
    :param epoch: current epoch number
    :param global_step: current global training step
    :param metrics: evaluation metrics of the epoch
    :param is_best: if True, save to 'best.pth' instead of 'latest.pth'
    :param logger: optional logger; a default one is created if None
    :param prefix: optional file name prefix that overrides the best/latest naming
    """
    if logger is None:
        logger = get_logger()

    if prefix is None:
        if is_best:
            save_path = os.path.join(cfg["Global"]["output_dir"], "best.pth")
        else:
            save_path = os.path.join(cfg["Global"]["output_dir"], "latest.pth")
    else:
        save_path = os.path.join(cfg["Global"]["output_dir"], prefix + ".pth")

    # Unwrap DistributedDataParallel so the checkpoint can be loaded into a
    # non-distributed model later.
    state_dict = model.module.state_dict() if cfg["Global"]["distributed"] else model.state_dict()
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": state_dict,
        # Optimizer/scheduler states are only needed to resume training, so
        # they are dropped from the "best" checkpoint to keep it small.
        "optimizer": None if is_best else optimizer.state_dict(),
        "scheduler": None if is_best else lr_scheduler.state_dict(),
        "config": cfg,
        "metrics": metrics,
    }
    torch.save(state, save_path)
    logger.info(f"save ckpt to {save_path}")


def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
    """
    Resume from a saved checkpoint or load pretrained weights.

    :param model: model to load weights into
    :param cfg: config dict; cfg["Global"]["checkpoints"] takes priority over
        cfg["Global"]["pretrained_model"]
    :param optimizer: optional optimizer to restore when resuming
    :param lr_scheduler: optional LR scheduler to restore when resuming
    :param logger: optional logger; a default one is created if None
    :return: dict with 'epoch', 'global_step' and 'metrics' when resuming,
        otherwise an empty dict
    """
    if logger is None:
        logger = get_logger()

    checkpoints = cfg["Global"].get("checkpoints")
    pretrained_model = cfg["Global"].get("pretrained_model")

    status = {}
    if checkpoints and os.path.exists(checkpoints):
        # Resume training: restore model, optimizer and scheduler states.
        checkpoint = torch.load(checkpoints, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint["scheduler"])
        logger.info(f"resume from checkpoint {checkpoints} (epoch {checkpoint['epoch']})")

        status["global_step"] = checkpoint["global_step"]
        status["epoch"] = checkpoint["epoch"] + 1
        status["metrics"] = checkpoint["metrics"]
    elif pretrained_model and os.path.exists(pretrained_model):
        # Fine-tune: load only the model weights, tolerating missing keys.
        load_pretrained_params(model, pretrained_model, logger)
        logger.info(f"finetune from checkpoint {pretrained_model}")
    else:
        logger.info("train from scratch")
    return status


def load_pretrained_params(model, pretrained_model, logger):
    checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
    # strict=False allows partial loading, e.g. when the prediction head differs.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    # Report model parameters that were not found in the pretrained checkpoint.
    for name in model.state_dict().keys():
        if name not in checkpoint["state_dict"]:
            logger.info(f"{name} is not in pretrained model")
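

# Usage sketch (illustrative only): how a training loop might call load_ckpt /
# save_ckpt. The config keys mirror the access pattern above; the `model`,
# `optimizer`, `lr_scheduler`, `epoch`, `global_step` and `best_acc` names are
# hypothetical placeholders, not part of this module.
#
#   cfg = {
#       "Global": {
#           "output_dir": "./output",
#           "distributed": False,
#           "checkpoints": None,          # path to resume from, if any
#           "pretrained_model": None,     # path to fine-tune from, if any
#       }
#   }
#   status = load_ckpt(model, cfg, optimizer, lr_scheduler)
#   start_epoch = status.get("epoch", 0)
#   global_step = status.get("global_step", 0)
#   ...
#   save_ckpt(model, cfg, optimizer, lr_scheduler, epoch, global_step,
#             metrics={"acc": best_acc})                 # writes latest.pth
#   save_ckpt(model, cfg, optimizer, lr_scheduler, epoch, global_step,
#             metrics={"acc": best_acc}, is_best=True)   # writes best.pth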