Spaces:
Configuration error
Configuration error
from torch.optim import lr_scheduler | |
from models.backbone.skip import skip | |
def get_scheduler(optimizer, opt):
    """Build the learning-rate scheduler selected by ``opt.lr_policy``.

    Supported policies and the options each one reads:

    * ``"linear"``  -- ``LambdaLR``. The multiplier stays at 1.0 while
      ``epoch + opt.epoch_count <= opt.n_epochs`` and afterwards decreases
      by ``1 / (opt.n_epochs_decay + 1)`` per epoch.
    * ``"step"``    -- ``StepLR`` with ``step_size=opt.lr_decay_iters`` and
      ``gamma=0.1``.
    * ``"plateau"`` -- ``ReduceLROnPlateau`` with ``mode="min"``,
      ``factor=0.2``, ``threshold=0.01`` and ``patience=5``.
    * ``"cosine"``  -- ``CosineAnnealingLR`` with ``T_max=opt.n_epochs`` and
      ``eta_min=0``.

    Args:
        optimizer: The optimizer whose learning rate is to be scheduled.
        opt: Options object exposing ``lr_policy`` plus the attributes
            required by the chosen policy (see above).

    Returns:
        The constructed ``torch.optim.lr_scheduler`` scheduler.

    Raises:
        NotImplementedError: If ``opt.lr_policy`` is not one of the
            supported policy names.
    """
    if opt.lr_policy == "linear":

        def lambda_rule(epoch):
            # Number of epochs past the constant-rate phase (never negative),
            # scaled so the multiplier shrinks linearly during the decay phase.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == "step":
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == "plateau":
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # Raise (rather than return) the error so an unsupported policy fails
        # immediately instead of handing the caller an exception object in
        # place of a scheduler; interpolate the policy name into the message.
        raise NotImplementedError("learning rate policy [%s] is not implemented" % opt.lr_policy)
    return scheduler
def define_G(cfg):
    """Construct the ``skip`` generator network described by ``cfg``.

    ``cfg["skip_n33d"]``, ``cfg["skip_n33u"]`` and ``cfg["skip_n11"]`` may
    each be either an ``int`` or an already-built per-scale value. An ``int``
    is repeated ``cfg["num_scales"]`` times to form a list; any other value
    is forwarded to ``skip`` unchanged.

    Args:
        cfg: Mapping holding the keys named above (``"num_scales"`` is only
            read when at least one of the channel settings is an ``int``).

    Returns:
        The network returned by ``skip(3, 4, ...)`` with ``need_bias=True``.
    """

    def per_scale(key):
        # Broadcast a scalar setting to one entry per scale; pass anything
        # else through as-is.
        setting = cfg[key]
        if isinstance(setting, int):
            return [setting] * cfg["num_scales"]
        return setting

    channels_down = per_scale("skip_n33d")
    channels_up = per_scale("skip_n33u")
    channels_skip = per_scale("skip_n11")

    return skip(
        3,
        4,
        num_channels_down=channels_down,
        num_channels_up=channels_up,
        num_channels_skip=channels_skip,
        need_bias=True,
    )