PKaushik's picture
commit
2c03f99
raw
history blame
1.5 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import math
import torch
import torch.nn as nn
def build_optimizer(cfg, model):
    """Build an optimizer for ``model`` from ``cfg.solver``.

    Module parameters are split into three param groups, in this order:

    * group 0 -- weights of ``nn.BatchNorm2d`` modules (no weight decay),
    * group 1 -- every other module ``weight`` (``cfg.solver.weight_decay``),
    * group 2 -- every module ``bias`` (no weight decay).

    Only attributes literally named ``weight`` / ``bias`` that are
    ``nn.Parameter`` instances are collected; parameters registered under
    other names are not optimized. NOTE(review): only ``nn.BatchNorm2d``
    weights are exempt from decay -- other norm layers (e.g. BatchNorm1d,
    SyncBatchNorm, LayerNorm) land in group 1; confirm this is intended.

    Args:
        cfg: config object; reads ``cfg.solver.optim`` ('SGD' or 'Adam'),
            ``cfg.solver.lr0``, ``cfg.solver.momentum`` and
            ``cfg.solver.weight_decay``.
        model: the ``nn.Module`` whose parameters are optimized.

    Returns:
        ``torch.optim.SGD`` (Nesterov momentum) or ``torch.optim.Adam``
        (``betas=(momentum, 0.999)``) with the three groups above.

    Raises:
        ValueError: if ``cfg.solver.optim`` is neither 'SGD' nor 'Adam'.
    """
    solver = cfg.solver
    optim_name = solver.optim
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # and an unknown name would otherwise leave no optimizer to return.
    if optim_name not in ('SGD', 'Adam'):
        raise ValueError(
            f"ERROR: unknown optimizer {optim_name!r}, expected 'SGD' or 'Adam'"
        )

    g_bnw, g_w, g_b = [], [], []
    for module in model.modules():
        bias = getattr(module, 'bias', None)
        if isinstance(bias, nn.Parameter):
            g_b.append(bias)
        weight = getattr(module, 'weight', None)
        # Skip modules without a learnable weight, e.g. BatchNorm2d with
        # affine=False whose weight is None (the optimizer rejects None).
        if not isinstance(weight, nn.Parameter):
            continue
        if isinstance(module, nn.BatchNorm2d):
            g_bnw.append(weight)
        else:
            g_w.append(weight)

    # All groups are passed to the constructor at once so that an empty
    # group (e.g. a model without BatchNorm2d) is allowed; the group order
    # is kept because callers may address param_groups by index.
    param_groups = [
        {'params': g_bnw},
        {'params': g_w, 'weight_decay': solver.weight_decay},
        {'params': g_b},
    ]
    if optim_name == 'SGD':
        return torch.optim.SGD(
            param_groups, lr=solver.lr0, momentum=solver.momentum, nesterov=True
        )
    return torch.optim.Adam(
        param_groups, lr=solver.lr0, betas=(solver.momentum, 0.999)
    )
def build_lr_scheduler(cfg, optimizer, epochs):
    """Build a ``LambdaLR`` learning-rate scheduler from ``cfg.solver``.

    Only the 'Cosine' schedule is implemented. Its LR multiplier follows a
    half cosine from 1 at step 0 down to ``cfg.solver.lrf`` at step
    ``epochs``. Any other ``cfg.solver.lr_scheduler`` value is logged as an
    error, and the cosine schedule is used instead.

    Args:
        cfg: config object; reads ``cfg.solver.lr_scheduler`` and
            ``cfg.solver.lrf``.
        optimizer: the ``torch.optim.Optimizer`` whose LRs are scheduled.
        epochs: length of the schedule. It must be non-zero because the
            multiplier divides by it. Presumably this is the number of
            training epochs; confirm against the caller.

    Returns:
        ``(scheduler, lf)``: the ``LambdaLR`` instance and the multiplier
        function ``lf(x) -> float`` it applies to each group's base LR.
    """
    if cfg.solver.lr_scheduler != 'Cosine':
        # No LOGGER is defined or imported in the visible module, so use the
        # stdlib logger. Execution then falls through to the cosine schedule
        # promised by the message instead of failing on an unbound `lf`.
        import logging
        logging.getLogger(__name__).error('unknown lr scheduler, use Cosine defaulted')

    lrf = cfg.solver.lrf

    def lf(x):
        # Half-cosine interpolation from 1 (x=0) to lrf (x=epochs).
        return ((1 - math.cos(x * math.pi / epochs)) / 2) * (lrf - 1) + 1

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return scheduler, lf