Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import glob | |
import os | |
import os.path as osp | |
import warnings | |
import mmcv | |
import torch | |
from mmcv.utils import TORCH_VERSION, digit_version, print_log | |
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    A ``latest.{suffix}`` file inside ``path`` takes precedence. Otherwise
    every ``*.{suffix}`` file in ``path`` is examined. The one whose file
    name ends in the largest ``_<number>`` (e.g. ``epoch_12.pth``) is
    returned. Files whose name does not end in ``_<number>`` (e.g.
    ``model.pth``) are skipped rather than aborting the search.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension.
            Defaults to pth.

    Returns:
        latest_path (str | None): File path of the latest checkpoint, or
            None when ``path`` does not exist or contains no numbered
            checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
                  /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None
    if osp.exists(osp.join(path, f'latest.{suffix}')):
        return osp.join(path, f'latest.{suffix}')

    # Escape the directory so glob metacharacters in it (e.g. ``[``) are
    # matched literally instead of being interpreted as a pattern.
    checkpoints = glob.glob(osp.join(glob.escape(path), f'*.{suffix}'))
    if len(checkpoints) == 0:
        warnings.warn('There are no checkpoints in the path.')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # "epoch_12.pth" -> "12.pth" -> "12"
        number = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(number)
        except ValueError:
            # Not a numbered checkpoint (e.g. "model.pth"); ignore it.
            continue
        if count > latest:
            latest = count
            latest_path = checkpoint
    if latest_path is None:
        warnings.warn('There are no numbered checkpoints in the path.')
    return latest_path
def update_data_root(cfg, logger=None):
    """Update data root according to env MMDET_DATASETS.

    If the environment variable ``MMDET_DATASETS`` is set, two things happen
    in place:

    * Every string value under ``cfg.data`` (descending into nested
      ``mmcv.ConfigDict`` values) that contains the old ``cfg.data_root``
      has that substring replaced by the value of ``MMDET_DATASETS``.
    * ``cfg.data_root`` is set to that value.

    Otherwise ``cfg.data_root`` is used as-is and ``cfg`` is not modified.

    Args:
        cfg (mmcv.Config): The model config need to modify.
        logger (logging.Logger | str | None): The way to print msg; forwarded
            to ``print_log``.
    """
    assert isinstance(cfg, mmcv.Config), \
        f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'

    if 'MMDET_DATASETS' not in os.environ:
        return
    dst_root = os.environ['MMDET_DATASETS']
    print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
              f'Using {dst_root} as data root.', logger=logger)

    def update(cfg, src_str, dst_str):
        # Replace ``src_str`` with ``dst_str`` in every string value,
        # recursing into nested ConfigDicts. Only existing keys are
        # reassigned, so mutating while iterating is safe.
        for k, v in cfg.items():
            if isinstance(v, mmcv.ConfigDict):
                update(cfg[k], src_str, dst_str)
            if isinstance(v, str) and src_str in v:
                cfg[k] = v.replace(src_str, dst_str)

    update(cfg.data, cfg.data_root, dst_root)
    cfg.data_root = dst_root
# Whether ``floordiv`` may call ``torch.div`` with ``rounding_mode``: true only
# for non-parrots builds whose version is at least 1.8.
_torch_version_div_indexing = (
    digit_version(TORCH_VERSION) >= digit_version('1.8')
    if 'parrots' not in TORCH_VERSION else False)
def floordiv(dividend, divisor, rounding_mode='trunc'):
    """Divide ``dividend`` by ``divisor`` with integer rounding.

    When the module-level flag ``_torch_version_div_indexing`` is true, the
    work is delegated to ``torch.div`` with the given ``rounding_mode``.
    Otherwise the ``//`` operator is used.

    NOTE(review): the ``//`` fallback does not consult ``rounding_mode``;
    confirm callers do not rely on a non-default mode on older torch.

    Args:
        dividend: Numerator (presumably a ``torch.Tensor`` or number).
        divisor: Denominator.
        rounding_mode (str): Forwarded to ``torch.div``. Defaults to 'trunc'.

    Returns:
        The quotient produced by ``torch.div`` or ``//``.
    """
    if not _torch_version_div_indexing:
        return dividend // divisor
    return torch.div(dividend, divisor, rounding_mode=rounding_mode)