|
from .default_helper import deep_merge_dicts |
|
from easydict import EasyDict |
|
|
|
|
|
class Scheduler:
    """
    Overview:
        Update a learning parameter when the monitored metric has stopped improving.
        For example, models often benefit from reducing the entropy weight once the learning process stagnates.
        This scheduler reads a metric value on every ``step`` call; if no improvement is seen for more than
        ``patience`` consecutive calls, the parameter is increased or decreased according to ``schedule_mode``.
        "Improvement" is measured against the metric passed on the previous call, not the best value seen so far.
    Arguments:
        - schedule_flag (:obj:`bool`): Indicates whether to use the scheduler in the training pipeline.
            This class itself does not read it. Default: False
        - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode
            decides how the parameter is updated. Default: 'reduce'.
        - factor (:obj:`float`): Amount (greater than 0) by which the parameter is changed: subtracted in
            'reduce', added in 'add', multiplied in 'multi', divided in 'div'. Default: 0.05
        - change_range (:obj:`list`): The minimum and maximum value the parameter can reach,
            respectively. Default: [-1, 1]
        - threshold (:obj:`float`): Threshold for measuring the new optimum, so that only
            significant changes count as improvement. Default: 1e-4.
        - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of the
            optimization objective. Dynamic_threshold = last_metrics + threshold in `max` mode or
            last_metrics - threshold in `min` mode. Default: 'min'
        - patience (:obj:`int`): Number of steps with no improvement after which the parameter
            is updated. For example, if `patience = 2`, the first 2 steps with no improvement are
            ignored, and the parameter is only updated after the 3rd step if the metric still
            hasn't improved by then. Default: 10.
        - cooldown (:obj:`int`): Number of steps to wait before resuming normal operation after
            the parameter has been updated. The scheduler also starts in cooldown, so the first
            ``cooldown`` steps after construction are ignored as well. Default: 0.
    Interfaces:
        __init__, step, update_param, is_better
    Property:
        in_cooldown
    """

    config = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=10,
        cooldown=0,
    )

    # Update rules keyed by schedule_mode, each mapping (param, factor, change_range) -> new param.
    # Defined once at class level instead of being rebuilt on every update_param call.
    # The result is clamped to change_range[0] (lower) or change_range[1] (upper); for 'multi'/'div'
    # the bound used depends on whether factor >= 1.
    _SCHEDULE_FNS = {
        'reduce': lambda x, y, z: max(x - y, z[0]),
        'add': lambda x, y, z: min(x + y, z[1]),
        'multi': lambda x, y, z: min(x * y, z[1]) if y >= 1 else max(x * y, z[0]),
        'div': lambda x, y, z: max(x / y, z[0]) if y >= 1 else min(x / y, z[1]),
    }

    def __init__(self, merged_scheduler_config: 'EasyDict') -> None:
        """
        Overview:
            Validate the scheduler config and initialize the scheduler state.
        Arguments:
            - merged_scheduler_config (:obj:`EasyDict`): The scheduler config, i.e. the user config
                merged with the default ``Scheduler.config``. Values are read by attribute access only.
        Raises:
            - AssertionError: If any config value has the wrong type or is out of range.
        """
        schedule_mode = merged_scheduler_config.schedule_mode
        factor = merged_scheduler_config.factor
        change_range = merged_scheduler_config.change_range
        threshold = merged_scheduler_config.threshold
        optimize_mode = merged_scheduler_config.optimize_mode
        patience = merged_scheduler_config.patience
        cooldown = merged_scheduler_config.cooldown

        assert schedule_mode in [
            'reduce', 'add', 'multi', 'div'
        ], 'The schedule mode should be one of [\'reduce\', \'add\', \'multi\',\'div\']'
        self.schedule_mode = schedule_mode

        assert isinstance(factor, (float, int)), 'The factor should be a float/int number '
        assert factor > 0, 'The factor should be greater than 0'
        self.factor = float(factor)

        assert isinstance(change_range,
                          list) and len(change_range) == 2, 'The change_range should be a list with 2 float numbers'
        assert (isinstance(change_range[0], (float, int))) and (
            isinstance(change_range[1], (float, int))
        ), 'The change_range should be a list with 2 float/int numbers'
        assert change_range[0] < change_range[1], 'The first num should be smaller than the second num'
        # Copy so later mutation of the caller's config cannot silently change (or invalidate)
        # the bounds validated above.
        self.change_range = list(change_range)

        assert isinstance(threshold, (float, int)), 'The threshold should be a float/int number'
        self.threshold = threshold

        assert optimize_mode in ['min', 'max'], 'The optimize_mode should be one of [\'min\', \'max\']'
        self.optimize_mode = optimize_mode

        assert isinstance(patience, int), 'The patience should be a integer greater than or equal to 0'
        assert patience >= 0, 'The patience should be a integer greater than or equal to 0'
        self.patience = patience

        assert isinstance(cooldown, int), 'The cooldown_counter should be a integer greater than or equal to 0'
        assert cooldown >= 0, 'The cooldown_counter should be a integer greater than or equal to 0'
        self.cooldown = cooldown
        # Start in cooldown: the first `cooldown` steps never trigger an update.
        self.cooldown_counter = cooldown

        # Metric passed on the previous step (None until the first step).
        self.last_metrics = None
        # Consecutive steps without improvement.
        self.bad_epochs_num = 0

    def step(self, metrics: float, param: float) -> float:
        """
        Overview:
            Record one metric observation and decide whether to update the scheduled parameter.
        Arguments:
            - metrics (:obj:`float`): Current metric value; must already be a Python float.
            - param (:obj:`float`): Current value of the parameter to schedule.
        Returns:
            - step_param (:obj:`float`): The parameter after this step (unchanged unless patience ran out).
        Raises:
            - AssertionError: If ``metrics`` is not a float.
        """
        assert isinstance(metrics, float), 'The metrics should be converted to a float number'
        cur_metrics = metrics

        if self.is_better(cur_metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        # Comparison baseline for the next step is always the latest metric.
        self.last_metrics = cur_metrics

        if self.in_cooldown:
            self.cooldown_counter -= 1
            # Bad steps observed during cooldown are ignored.
            self.bad_epochs_num = 0

        if self.bad_epochs_num > self.patience:
            param = self.update_param(param)
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return param

    def update_param(self, param: float) -> float:
        """
        Overview:
            Apply the configured update rule to the parameter, clamped to ``change_range``.
        Arguments:
            - param (:obj:`float`): Parameter to update.
        Returns:
            - updated_param (:obj:`float`): Parameter after updating.
        Raises:
            - KeyError: If ``self.schedule_mode`` is not a known mode (e.g. it was reassigned after init).
        """
        schedule_mode_list = list(self._SCHEDULE_FNS)
        if self.schedule_mode not in schedule_mode_list:
            raise KeyError("invalid schedule_mode({}) in {}".format(self.schedule_mode, schedule_mode_list))
        return self._SCHEDULE_FNS[self.schedule_mode](param, self.factor, self.change_range)

    @property
    def in_cooldown(self) -> bool:
        """
        Overview:
            Check whether the scheduler is in its cooldown period. While in cooldown,
            the scheduler ignores any bad steps.
        """
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        """
        Overview:
            Check whether the current metric improves on the previous step's metric by more than ``threshold``.
        Arguments:
            - cur (:obj:`float`): Current metric value.
        Returns:
            - better (:obj:`bool`): True on the first call (no previous metric), otherwise whether
                ``cur`` beats ``last_metrics`` by more than ``threshold`` in the direction of ``optimize_mode``.
        """
        if self.last_metrics is None:
            return True
        elif self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        elif self.optimize_mode == 'max':
            return cur > self.last_metrics + self.threshold
|