#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: Tu Bui @University of Surrey
"""
class SimpleLossWeightScheduler:
    """Piecewise-linear schedule for a scalar loss weight, indexed by step.

    The weight is 1 for the first ``wait_steps`` steps. It then moves
    linearly over ``ramp`` steps towards ``simple_loss_weight_max`` and
    stays there afterwards. If ``simple_loss_weight_max`` is below 1, the
    weight ramps down instead of up.

    Args:
        simple_loss_weight_max: Final weight reached after the ramp.
        wait_steps: Number of initial steps during which the weight is 1.
        ramp: Length of the linear transition in steps. A value <= 0 means
            an instantaneous switch to the final weight once ``step``
            exceeds ``wait_steps``.
    """

    def __init__(self, simple_loss_weight_max=10., wait_steps=50000, ramp=100000) -> None:
        self.simple_loss_weight_max = simple_loss_weight_max
        self.wait_steps = wait_steps
        self.ramp = ramp

    def __call__(self, step):
        """Return the loss weight for training step ``step``."""
        # Offset from the base weight of 1; may be negative (ramp down).
        max_weight = self.simple_loss_weight_max - 1
        if self.ramp > 0:
            delta = max_weight * (step - self.wait_steps) / self.ramp
        else:
            # Zero-length ramp: avoid dividing by zero and switch directly.
            delta = max_weight if step > self.wait_steps else 0.
        # Clamp between 0 and max_weight whatever the sign of max_weight.
        # For max_weight >= 0 this is exactly min(max_weight, max(0., delta)).
        lo = min(0., max_weight)
        hi = max(0., max_weight)
        w = 1 + min(hi, max(lo, delta))
        return w