# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import warnings

import mmcv
import torch
from mmcv.utils import TORCH_VERSION, digit_version, print_log


def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    If ``latest.{suffix}`` exists in ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file in ``path`` is inspected, and the one
    whose basename ends in ``_<int>.{suffix}`` (e.g. ``epoch_12.pth``) with
    the largest integer is returned. Files whose name does not carry such a
    number (e.g. ``model.pth``) are skipped rather than raising
    ``ValueError``.

    Args:
        path(str): The path to find checkpoints.
        suffix(str): File extension. Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint, or None
        if ``path`` does not exist or contains no usable checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None
    if osp.exists(osp.join(path, f'latest.{suffix}')):
        return osp.join(path, f'latest.{suffix}')

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if len(checkpoints) == 0:
        warnings.warn('There are no checkpoints in the path.')
        return None
    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # The count is the text between the last '_' and the first '.' of the
        # basename, e.g. 'iter_8000.pth' -> '8000'.
        count_str = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(count_str)
        except ValueError:
            # Unnumbered files (e.g. 'model.pth') cannot be ordered; ignore
            # them instead of aborting the whole search.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path


def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If the environment variable MMDET_DATASETS is set, every string value in
    ``cfg.data`` (searched recursively through nested ``mmcv.ConfigDict``s)
    that contains the current ``cfg.data_root`` has that substring replaced
    by the value of MMDET_DATASETS, and ``cfg.data_root`` is then set to that
    value. Otherwise ``cfg`` is left untouched.

    Args:
        cfg (mmcv.Config): The model config need to modify. It is modified
            in place.
        logger (logging.Logger | str | None): the way to print msg. Forwarded
            to ``print_log``.
    """
    assert isinstance(cfg, mmcv.Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'

    if 'MMDET_DATASETS' not in os.environ:
        return
    dst_root = os.environ['MMDET_DATASETS']
    print_log(
        f'MMDET_DATASETS has been set to be {dst_root}. '
        f'Using {dst_root} as data root.',
        logger=logger)

    def update(cfg, src_str, dst_str):
        # Recursively rewrite string values; only existing keys are
        # reassigned, so mutating while iterating items() is safe.
        for k, v in cfg.items():
            if isinstance(v, mmcv.ConfigDict):
                update(cfg[k], src_str, dst_str)
            if isinstance(v, str) and src_str in v:
                cfg[k] = v.replace(src_str, dst_str)

    update(cfg.data, cfg.data_root, dst_root)
    cfg.data_root = dst_root


# True when ``torch.div`` can be called with ``rounding_mode`` (PyTorch >= 1.8);
# parrots builds are always routed to the ``//`` fallback.
_torch_version_div_indexing = (
    'parrots' not in TORCH_VERSION
    and digit_version(TORCH_VERSION) >= digit_version('1.8'))


def floordiv(dividend, divisor, rounding_mode='trunc'):
    """Divide ``dividend`` by ``divisor`` in a torch-version-compatible way.

    When ``_torch_version_div_indexing`` is true this calls
    ``torch.div(dividend, divisor, rounding_mode=rounding_mode)``; otherwise
    it falls back to ``dividend // divisor`` and ``rounding_mode`` is
    ignored.

    Note:
        Despite the name, the default ``rounding_mode`` is ``'trunc'``
        (round toward zero), which differs from floor division when the
        quotient is negative. Pass ``rounding_mode='floor'`` for true floor
        division on the ``torch.div`` path.

    Args:
        dividend: Numerator, passed straight to ``torch.div`` / ``//``.
        divisor: Denominator, passed straight to ``torch.div`` / ``//``.
        rounding_mode (str | None): Forwarded to ``torch.div``.
            Defaults to 'trunc'.

    Returns:
        The result of ``torch.div`` or of ``//``.
    """
    if _torch_version_div_indexing:
        return torch.div(dividend, divisor, rounding_mode=rounding_mode)
    else:
        return dividend // divisor