"""
@Date: 2021/07/18
@description: Build the training optimizer (SGD / AdamW / Adam) from the config.
"""
|
from torch import optim


def build_optimizer(config, model, logger):
    """Create an optimizer for the model parameters according to config.TRAIN.OPTIMIZER."""
    name = config.TRAIN.OPTIMIZER.NAME.lower()

    optimizer = None
    if name == 'sgd':
        optimizer = optim.SGD(model.parameters(), momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif name == 'adamw':
        optimizer = optim.AdamW(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif name == 'adam':
        optimizer = optim.Adam(model.parameters(), eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                               lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    else:
        # Fail loudly instead of silently returning None for an unsupported optimizer name.
        raise NotImplementedError(f"Unsupported optimizer: {name}")

    logger.info(f"Build optimizer: {name}, lr: {config.TRAIN.BASE_LR}")
    return optimizer
|
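
# Minimal usage sketch. Assumption: the real project passes a yacs-style CfgNode;
# a SimpleNamespace with the same attribute layout is used here so the demo runs
# without extra dependencies. The field values below are illustrative only.
if __name__ == '__main__':
    import logging
    from types import SimpleNamespace

    import torch.nn as nn

    cfg = SimpleNamespace(
        TRAIN=SimpleNamespace(
            BASE_LR=1e-3,
            WEIGHT_DECAY=0.05,
            OPTIMIZER=SimpleNamespace(NAME='adamw', EPS=1e-8, BETAS=(0.9, 0.999), MOMENTUM=0.9),
        )
    )

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger('build_optimizer_demo')
    demo_model = nn.Linear(8, 2)  # tiny stand-in model

    opt = build_optimizer(cfg, demo_model, demo_logger)
    print(opt)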