Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch | |
from .lr_scheduler import WarmupMultiStepLR | |
def make_optimizer(cfg, model):
    """Build the optimizer selected by ``cfg.SOLVER`` for ``model``.

    One parameter group is assembled per trainable parameter. Parameters
    whose name contains ``"bias"`` use ``BASE_LR * BIAS_LR_FACTOR`` and
    ``WEIGHT_DECAY_BIAS``; all others use ``BASE_LR`` and ``WEIGHT_DECAY``.

    Args:
        cfg: Config object exposing ``SOLVER.BASE_LR``, ``SOLVER.WEIGHT_DECAY``,
            ``SOLVER.BIAS_LR_FACTOR``, ``SOLVER.WEIGHT_DECAY_BIAS``,
            ``SOLVER.USE_ADAM`` and ``SOLVER.MOMENTUM``.
        model: Module providing ``named_parameters()`` and ``parameters()``.

    Returns:
        ``torch.optim.Adam`` when ``cfg.SOLVER.USE_ADAM`` is truthy,
        otherwise ``torch.optim.SGD`` built from the per-parameter groups.
    """
    # The optimizer-wide default learning rate is always the base rate. It
    # used to be whatever ``lr`` the last loop iteration left behind, which
    # was the bias rate when the final parameter was a bias, and was
    # undefined (NameError) when no parameter required gradients.
    base_lr = cfg.SOLVER.BASE_LR

    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = base_lr
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = base_lr * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params.append({"params": [value], "lr": lr, "weight_decay": weight_decay})

    if cfg.SOLVER.USE_ADAM:
        # NOTE(review): Adam is handed model.parameters(), not the groups
        # built above, so the bias LR factor and weight-decay settings do not
        # apply on this branch -- confirm whether that is intentional.
        optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)
    else:
        optimizer = torch.optim.SGD(params, base_lr, momentum=cfg.SOLVER.MOMENTUM)
    return optimizer
def make_lr_scheduler(cfg, optimizer):
    """Create the ``WarmupMultiStepLR`` scheduler configured by ``cfg.SOLVER``.

    ``SOLVER.STEPS`` and ``SOLVER.GAMMA`` are passed positionally after the
    optimizer; the warmup settings, ``POW_SCHEDULE`` and ``MAX_ITER`` are
    forwarded as keyword arguments.

    Args:
        cfg: Config object exposing the ``SOLVER`` fields read below.
        optimizer: Optimizer instance the scheduler will drive.

    Returns:
        The constructed ``WarmupMultiStepLR`` instance.
    """
    solver = cfg.SOLVER
    steps = solver.STEPS
    gamma = solver.GAMMA
    scheduler_options = {
        "warmup_factor": solver.WARMUP_FACTOR,
        "warmup_iters": solver.WARMUP_ITERS,
        "warmup_method": solver.WARMUP_METHOD,
        "pow_schedule_mode": solver.POW_SCHEDULE,
        "max_iter": solver.MAX_ITER,
    }
    return WarmupMultiStepLR(optimizer, steps, gamma, **scheduler_options)