|
""" |
|
@Date: 2021/07/17 |
|
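@description: Default yacs configuration tree, plus helpers that merge YAML-file and command-line overrides into a frozen run config and derive per-rank (distributed) settings.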
|
""" |
|
import os |
|
import logging |
|
from yacs.config import CfgNode as CN |
|
|
|
# Root configuration node. Every default below can be overridden by a YAML
# file (merge_from_file / get_config) or, for some keys, by command-line
# arguments passed to get_config.
_C = CN()
# Debug flag; get_config appends a 'debug' sub-directory to CKPT.DIR when set.
_C.DEBUG = False
# Run mode; get_config uses it as the last component of CKPT.RESULT_DIR.
_C.MODE = 'train'
# Validation set name; overridable through the `val_name` argument.
_C.VAL_NAME = 'val'
# Experiment tag; part of the CKPT.DIR path.
_C.TAG = 'default'
# Free-form note describing the experiment.
_C.COMMENT = 'add some comments to help you understand'
# Progress-bar switch; get_config sets it to False when `hidden_bar` is given.
_C.SHOW_BAR = True
# Set to True by get_config when `save_eval` is given (presumably saves
# evaluation outputs — confirm with the evaluator).
_C.SAVE_EVAL = False
|
# Model options.
_C.MODEL = CN()
# Default model name; used in the import-time CKPT.DIR default.
_C.MODEL.NAME = 'model_name'
# Presumably whether to keep the best / the last checkpoint — TODO confirm
# with the trainer.
_C.MODEL.SAVE_BEST = True
_C.MODEL.SAVE_LAST = True
# List of model argument mappings (presumably kwargs for the model
# constructor — confirm). get_config reads ARGS[0]['decoder_name'] and
# ARGS[0]['output_name'] to build the checkpoint directory name.
_C.MODEL.ARGS = []
# Presumably names of parts to fine-tune — verify against the caller.
_C.MODEL.FINE_TUNE = []
|
|
|
|
|
|
|
|
|
# Training options.
_C.TRAIN = CN()
# Presumably "train from scratch instead of resuming" — TODO confirm.
_C.TRAIN.SCRATCH = False
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.DETERMINISTIC = False
# Checkpoint-saving frequency (presumably in epochs). A separate root-level
# _C.SAVE_FREQ is also defined in this file — confirm which one the trainer reads.
_C.TRAIN.SAVE_FREQ = 5

# Learning-rate / regularization settings.
_C.TRAIN.BASE_LR = 5e-4

_C.TRAIN.WARMUP_EPOCHS = 20
# NOTE(review): integer default. yacs type-checks values merged from a file
# against the default's type, so a float override such as 1e-4 may be
# rejected with a type-mismatch error — consider 0.0 after checking that no
# existing config file sets an integer here.
_C.TRAIN.WEIGHT_DECAY = 0
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6

# Presumably a gradient-norm clipping threshold — confirm.
_C.TRAIN.CLIP_GRAD = 5.0

_C.TRAIN.RESUME_LAST = True

_C.TRAIN.ACCUMULATION_STEPS = 0

_C.TRAIN.USE_CHECKPOINT = False

# Device string. get_rank_config accepts a bare device ('cuda') or
# 'cuda:<id>[,<id>...]'. With world_size > 1 it rewrites this per rank to
# 'cuda:<id>': the id is taken from the listed ids, or is the rank itself
# when no list is given.
_C.TRAIN.DEVICE = 'cuda'

# Learning-rate scheduler selection and its arguments (an empty NAME
# presumably means "no scheduler" — confirm).
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = ''
_C.TRAIN.LR_SCHEDULER.ARGS = []

# Optimizer selection. EPS/BETAS presumably apply to Adam-style optimizers and
# MOMENTUM to SGD-style ones — confirm with the optimizer factory.
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adam'

_C.TRAIN.OPTIMIZER.EPS = 1e-8

_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)

_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# Loss terms. Each sub-node has NAME, LOSS (presumably a loss class name),
# WEIGHT, WEIGHTS and NEED_ALL. Every WEIGHT defaults to 0.0, which presumably
# disables the term unless a config file sets it — confirm with the trainer.
_C.TRAIN.CRITERION = CN()

_C.TRAIN.CRITERION.BOUNDARY = CN()
_C.TRAIN.CRITERION.BOUNDARY.NAME = 'boundary'
_C.TRAIN.CRITERION.BOUNDARY.LOSS = 'BoundaryLoss'
_C.TRAIN.CRITERION.BOUNDARY.WEIGHT = 0.0
_C.TRAIN.CRITERION.BOUNDARY.WEIGHTS = []
_C.TRAIN.CRITERION.BOUNDARY.NEED_ALL = True

_C.TRAIN.CRITERION.LEDDepth = CN()
_C.TRAIN.CRITERION.LEDDepth.NAME = 'led_depth'
_C.TRAIN.CRITERION.LEDDepth.LOSS = 'LEDLoss'
_C.TRAIN.CRITERION.LEDDepth.WEIGHT = 0.0
_C.TRAIN.CRITERION.LEDDepth.WEIGHTS = []
_C.TRAIN.CRITERION.LEDDepth.NEED_ALL = True

_C.TRAIN.CRITERION.DEPTH = CN()
_C.TRAIN.CRITERION.DEPTH.NAME = 'depth'
_C.TRAIN.CRITERION.DEPTH.LOSS = 'L1Loss'
_C.TRAIN.CRITERION.DEPTH.WEIGHT = 0.0
_C.TRAIN.CRITERION.DEPTH.WEIGHTS = []
_C.TRAIN.CRITERION.DEPTH.NEED_ALL = False

_C.TRAIN.CRITERION.RATIO = CN()
_C.TRAIN.CRITERION.RATIO.NAME = 'ratio'
_C.TRAIN.CRITERION.RATIO.LOSS = 'L1Loss'
_C.TRAIN.CRITERION.RATIO.WEIGHT = 0.0
_C.TRAIN.CRITERION.RATIO.WEIGHTS = []
_C.TRAIN.CRITERION.RATIO.NEED_ALL = False

# Unlike the other terms, GRAD ships with two default sub-weights.
_C.TRAIN.CRITERION.GRAD = CN()
_C.TRAIN.CRITERION.GRAD.NAME = 'grad'
_C.TRAIN.CRITERION.GRAD.LOSS = 'GradLoss'
_C.TRAIN.CRITERION.GRAD.WEIGHT = 0.0
_C.TRAIN.CRITERION.GRAD.WEIGHTS = [1.0, 1.0]
_C.TRAIN.CRITERION.GRAD.NEED_ALL = True

_C.TRAIN.CRITERION.OBJECT = CN()
_C.TRAIN.CRITERION.OBJECT.NAME = 'object'
_C.TRAIN.CRITERION.OBJECT.LOSS = 'ObjectLoss'
_C.TRAIN.CRITERION.OBJECT.WEIGHT = 0.0
_C.TRAIN.CRITERION.OBJECT.WEIGHTS = []
_C.TRAIN.CRITERION.OBJECT.NEED_ALL = True

_C.TRAIN.CRITERION.CHM = CN()
_C.TRAIN.CRITERION.CHM.NAME = 'corner_heat_map'
_C.TRAIN.CRITERION.CHM.LOSS = 'HeatmapLoss'
_C.TRAIN.CRITERION.CHM.WEIGHT = 0.0
_C.TRAIN.CRITERION.CHM.WEIGHTS = []
_C.TRAIN.CRITERION.CHM.NEED_ALL = False

# Visualization settings (semantics not visible in this file — confirm).
_C.TRAIN.VIS_MERGE = True
_C.TRAIN.VIS_WEIGHT = 1024
|
|
|
|
|
|
|
# Checkpoint / output locations.
_C.CKPT = CN()
# Presumably a directory for pretrained PyTorch weights — TODO confirm.
_C.CKPT.PYTORCH = './'
# Root directory; get_config builds CKPT.DIR underneath it.
_C.CKPT.ROOT = "./checkpoints"
# Import-time defaults only. These joins are evaluated once from the default
# MODEL.NAME / TAG / MODE values and do not track later overrides, which is
# why get_config recomputes DIR, RESULT_DIR and LOGGER.DIR itself.
_C.CKPT.DIR = os.path.join(_C.CKPT.ROOT, _C.MODEL.NAME, _C.TAG)
_C.CKPT.RESULT_DIR = os.path.join(_C.CKPT.DIR, 'results', _C.MODE)

# Logging options.
_C.LOGGER = CN()
_C.LOGGER.DIR = os.path.join(_C.CKPT.DIR, "logs")
# Stored as the integer value of logging.DEBUG.
_C.LOGGER.LEVEL = logging.DEBUG
|
|
|
|
|
|
|
|
|
|
|
|
|
# Presumably the NVIDIA apex AMP optimization level ('O0'-'O3') — confirm.
_C.AMP_OPT_LEVEL = 'O1'

_C.OUTPUT = ''

# Root-level save frequency. This is distinct from TRAIN.SAVE_FREQ — confirm
# which one the trainer reads.
_C.SAVE_FREQ = 1

_C.PRINT_FREQ = 10

# Base random seed; get_rank_config adds the local rank to it per process.
_C.SEED = 0

_C.EVAL_MODE = False

_C.THROUGHPUT_MODE = False

# Filled in per process by get_rank_config.
_C.LOCAL_RANK = 0
_C.WORLD_SIZE = 0
|
|
|
|
|
|
|
|
|
# Dataset / data-loading options.
_C.DATA = CN()

# Optional subset selector (None = no subset; semantics depend on the loader).
_C.DATA.SUBSET = None

_C.DATA.DATASET = 'mp3d'

_C.DATA.DIR = ''

# Overridable through the `wall_num` argument of get_config.
_C.DATA.WALL_NUM = 0

# Presumably [height, width] of the input images — confirm with the loader.
_C.DATA.SHAPE = [512, 1024]

# Presumably in meters — confirm with the dataset code.
_C.DATA.CAMERA_HEIGHT = 1.6

_C.DATA.PIN_MEMORY = True

_C.DATA.FOR_TEST_INDEX = None

# Overridable through the `bs` argument of get_config.
_C.DATA.BATCH_SIZE = 8

# get_config overwrites this with 2 x the number of distinct 'physical id'
# entries in /proc/cpuinfo when that count can be read.
_C.DATA.NUM_WORKERS = 8

# Data augmentation switches.
_C.DATA.AUG = CN()

_C.DATA.AUG.FLIP = True

_C.DATA.AUG.STRETCH = True

_C.DATA.AUG.ROTATE = True

_C.DATA.AUG.GAMMA = True

_C.DATA.KEYS = []
|
|
|
|
|
# Evaluation options; each one can be set through the identically named
# lower-case argument of get_config (post_processing, need_cpe, need_f1,
# need_rmse, force_cube).
_C.EVAL = CN()
_C.EVAL.POST_PROCESSING = None
_C.EVAL.NEED_CPE = False
_C.EVAL.NEED_F1 = False
_C.EVAL.NEED_RMSE = False
_C.EVAL.FORCE_CUBE = False
|
|
|
|
|
def merge_from_file(cfg_path):
    """Return a copy of the default config with the YAML file ``cfg_path`` merged in.

    The module-level defaults are not modified. Unlike ``get_config``, the
    result is left unfrozen and the derived path fields are not recomputed.
    """
    merged = _C.clone()
    merged.merge_from_file(cfg_path)
    return merged
|
|
|
|
|
def _count_physical_cpus(cpuinfo_path='/proc/cpuinfo'):
    """Return the number of distinct 'physical id' lines in ``cpuinfo_path``.

    Pure-Python equivalent of
    ``grep 'physical id' /proc/cpuinfo | sort | uniq | wc -l`` that does not
    spawn a shell. Returns 0 when the file cannot be read (e.g. on non-Linux
    systems) or has no 'physical id' entries.
    """
    try:
        with open(cpuinfo_path, encoding='utf-8', errors='replace') as cpuinfo:
            return len({line.rstrip('\n') for line in cpuinfo if 'physical id' in line})
    except OSError:
        return 0


def get_config(args=None):
    """Build a frozen run config from the defaults plus optional overrides.

    Args:
        args: Optional argparse-style namespace. Each recognized attribute is
            applied only when present and truthy: ``cfg`` (YAML file, merged
            first), ``mode``, ``debug``, ``hidden_bar``, ``bs``, ``save_eval``,
            ``val_name``, ``post_processing``, ``need_cpe``, ``need_f1``,
            ``need_rmse``, ``force_cube`` and ``wall_num``.

    Returns:
        A frozen clone of the defaults.
        - CKPT.DIR, CKPT.RESULT_DIR and LOGGER.DIR are recomputed from the
          merged values.
        - DATA.NUM_WORKERS is set to twice the physical-CPU count when that
          count can be determined.
    """
    config = _C.clone()

    if args:
        if 'cfg' in args and args.cfg:
            config.merge_from_file(args.cfg)

        if 'mode' in args and args.mode:
            config.MODE = args.mode

        if 'debug' in args and args.debug:
            config.DEBUG = args.debug

        if 'hidden_bar' in args and args.hidden_bar:
            config.SHOW_BAR = False

        if 'bs' in args and args.bs:
            config.DATA.BATCH_SIZE = args.bs

        if 'save_eval' in args and args.save_eval:
            config.SAVE_EVAL = True

        if 'val_name' in args and args.val_name:
            config.VAL_NAME = args.val_name

        if 'post_processing' in args and args.post_processing:
            config.EVAL.POST_PROCESSING = args.post_processing

        if 'need_cpe' in args and args.need_cpe:
            config.EVAL.NEED_CPE = args.need_cpe

        if 'need_f1' in args and args.need_f1:
            config.EVAL.NEED_F1 = args.need_f1

        if 'need_rmse' in args and args.need_rmse:
            config.EVAL.NEED_RMSE = args.need_rmse

        if 'force_cube' in args and args.force_cube:
            config.EVAL.FORCE_CUBE = args.force_cube

        if 'wall_num' in args and args.wall_num:
            config.DATA.WALL_NUM = args.wall_num

    # The module-level path defaults were joined once at import time, so
    # rebuild them from the merged values. A separate name is used so the
    # ``args`` parameter is not shadowed.
    if config.MODEL.ARGS:
        model_args = config.MODEL.ARGS[0]
        model_dir = f"{model_args['decoder_name']}_{model_args['output_name']}_Net"
    else:
        # No model arguments configured: fall back to MODEL.NAME (as the
        # import-time default does) instead of raising IndexError.
        model_dir = config.MODEL.NAME
    config.CKPT.DIR = os.path.join(config.CKPT.ROOT, model_dir,
                                   config.TAG, 'debug' if config.DEBUG else '')
    config.CKPT.RESULT_DIR = os.path.join(config.CKPT.DIR, 'results', config.MODE)
    config.LOGGER.DIR = os.path.join(config.CKPT.DIR, "logs")

    # Workers = 2 x distinct 'physical id' entries (CPU packages, not cores).
    # When the count is unavailable, keep the configured value instead of
    # silently dropping to 0 workers.
    physical_cpus = _count_physical_cpus()
    if physical_cpus > 0:
        config.DATA.NUM_WORKERS = physical_cpus * 2
        print(f"System core number: {config.DATA.NUM_WORKERS}")
    else:
        print(f"Can't get system core number, will use config: {config.DATA.NUM_WORKERS}")

    config.freeze()
    return config
|
|
|
|
|
def get_rank_config(cfg, local_rank, world_size):
    """Return a frozen copy of ``cfg`` specialized for one distributed rank.

    Args:
        cfg: Base config (yacs ``CfgNode``). It is cloned, never modified.
        local_rank: Rank of this process; ``None`` is treated as 0.
        world_size: Number of processes. When greater than 1, TRAIN.DEVICE is
            rewritten to ``cuda:<id>``. ``<id>`` is taken from the
            comma-separated list after ``:`` in TRAIN.DEVICE (e.g.
            ``cuda:0,1``), or equals the rank itself when no list is given.

    Returns:
        The cloned, frozen config with LOCAL_RANK and WORLD_SIZE set and SEED
        offset by the local rank.

    Raises:
        IndexError: If TRAIN.DEVICE provides no device id for ``local_rank``.
    """
    local_rank = 0 if local_rank is None else local_rank

    config = cfg.clone()
    config.defrost()

    if world_size > 1:
        device = config.TRAIN.DEVICE
        if ':' in device:
            # Strip whitespace so 'cuda:0, 1' yields 'cuda:1', not 'cuda: 1'.
            ids = [dev_id.strip() for dev_id in device.split(':')[-1].split(',')]
        else:
            ids = range(world_size)
        if local_rank >= len(ids):
            raise IndexError(
                f"local_rank {local_rank} has no device in TRAIN.DEVICE={device!r} "
                f"(available ids: {list(ids)})")
        config.TRAIN.DEVICE = f'cuda:{ids[local_rank]}'

    config.LOCAL_RANK = local_rank
    config.WORLD_SIZE = world_size
    # Offset the seed per rank so each process draws a different random stream.
    config.SEED = config.SEED + local_rank

    config.freeze()
    return config
|
|