Spaces:
Runtime error
Runtime error
File size: 2,433 Bytes
29a229f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import os
from typing import Dict
from yacs.config import CfgNode as CN
def to_lower(x: Dict) -> Dict:
    """
    Convert all dictionary keys to lowercase.

    Only the top-level keys are converted; values (including nested
    dictionaries) are returned unchanged. Keys that have no ``lower``
    method (e.g. integers) are kept as they are instead of raising.
    If two keys differ only by case, the one iterated last wins.

    Args:
        x (dict): Input dictionary
    Returns:
        dict: New dictionary with all keys converted to lowercase
    """
    # hasattr (rather than isinstance(k, str)) keeps the previous behavior
    # for every key type that already worked, such as bytes.
    return {
        (k.lower() if hasattr(k, "lower") else k): v
        for k, v in x.items()
    }
def _new_node(**defaults) -> CN:
    """Create a CfgNode with ``new_allowed=True`` and set ``defaults`` on it."""
    node = CN(new_allowed=True)
    for key, value in defaults.items():
        # Same attribute-assignment path as writing ``node.KEY = value``.
        setattr(node, key, value)
    return node


# Root of the default configuration; every node is created with
# new_allowed=True, as in the flat assignments this replaces.
_C = _new_node()

_C.GENERAL = _new_node(
    RESUME=True,
    TIME_TO_RUN=3300,
    VAL_STEPS=100,
    LOG_STEPS=100,
    CHECKPOINT_STEPS=20000,
    CHECKPOINT_DIR="checkpoints",
    SUMMARY_DIR="tensorboard",
    NUM_GPUS=1,
    NUM_WORKERS=4,
    MIXED_PRECISION=True,
    ALLOW_CUDA=True,
    PIN_MEMORY=False,
    DISTRIBUTED=False,
    LOCAL_RANK=0,
    USE_SYNCBN=False,
    WORLD_SIZE=1,
)

_C.TRAIN = _new_node(
    NUM_EPOCHS=100,
    BATCH_SIZE=32,
    SHUFFLE=True,
    WARMUP=False,
    NORMALIZE_PER_IMAGE=False,
    CLIP_GRAD=False,
    CLIP_GRAD_VALUE=1.0,
)

# No defaults here; LOSS_WEIGHTS starts out empty.
_C.LOSS_WEIGHTS = _new_node()

# Data augmentation defaults live under DATASETS.CONFIG.
_C.DATASETS = _new_node(
    CONFIG=_new_node(
        SCALE_FACTOR=0.3,
        ROT_FACTOR=30,
        TRANS_FACTOR=0.02,
        COLOR_SCALE=0.2,
        ROT_AUG_RATE=0.6,
        TRANS_AUG_RATE=0.5,
        DO_FLIP=True,
        FLIP_AUG_RATE=0.5,
        EXTREME_CROP_AUG_RATE=0.10,
    ),
)

_C.MODEL = _new_node(IMAGE_SIZE=224)

_C.EXTRA = _new_node(FOCAL_LENGTH=5000)
def default_config() -> CN:
    """
    Return a fresh copy of the default configuration.

    Returns:
        CfgNode: Clone of the module-level defaults.
    """
    # Hand out a clone, never the module-level node itself, so callers
    # that modify their copy cannot alter the shared defaults.
    defaults = _C.clone()
    return defaults
def get_config(config_file: str, merge: bool = True) -> CN:
    """
    Load a config file into a frozen yacs CfgNode.

    Args:
        config_file (str): Path to config file.
        merge (bool): If True, the file is merged on top of the default
            config; otherwise it is merged into an empty node that
            accepts new keys.
    Returns:
        CfgNode: The resulting config, frozen before it is returned.
    """
    # Pick the starting node: a clone of the defaults or a blank node.
    base = default_config() if merge else CN(new_allowed=True)
    base.merge_from_file(config_file)
    base.freeze()
    return base
|