Vincentqyw
fix: roma
c74a070
raw
history blame
664 Bytes
"""
This file provides a factory that builds learning rate schedulers from a config.
"""
import torch
def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Build the learning rate scheduler described by the config.

    Args:
        lr_decay: Whether learning rate decay is enabled. Any falsy value
            (e.g. ``False`` or ``None``) means "not specified" and disables it.
        lr_decay_cfg: Scheduler config mapping, or ``None`` to disable decay.
            Must provide a ``"policy"`` key; the ``"exp"`` policy also reads
            a ``"gamma"`` key.
        optimizer: The optimizer whose learning rate is scheduled.

    Returns:
        A ``torch.optim.lr_scheduler.ExponentialLR`` for the ``"exp"`` policy,
        or ``None`` when decay is disabled.

    Raises:
        ValueError: If decay is enabled but the policy is missing or unknown.
        KeyError: If the ``"exp"`` policy is selected without a ``"gamma"`` key.
    """
    # If no lr_decay is specified (falsy flag) or there is no config => no scheduler.
    if not lr_decay or lr_decay_cfg is None:
        return None

    # Use .get() so a missing "policy" key reports the policy error below
    # instead of surfacing as a bare KeyError.
    policy = lr_decay_cfg.get("policy")

    # Exponential decay: the lr is multiplied by gamma on each scheduler.step().
    if policy == "exp":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay_cfg["gamma"]
        )

    # Unknown (or missing) policy.
    raise ValueError(f"[Error] Unknown learning rate decay policy: {policy!r}")