zhigangjiang's picture
no message
88b0dcb
raw
history blame
990 Bytes
"""
@Date: 2021/07/18
@description:
"""
from torch import optim as optim
def build_optimizer(config, model, logger):
name = config.TRAIN.OPTIMIZER.NAME.lower()
optimizer = None
if name == 'sgd':
optimizer = optim.SGD(model.parameters(), momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
elif name == 'adamw':
optimizer = optim.AdamW(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
elif name == 'adam':
optimizer = optim.Adam(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
logger.info(f"Build optimizer: {name}, lr:{config.TRAIN.BASE_LR}")
return optimizer