Spaces:
Runtime error
Runtime error
from typing import Any, Dict, Optional

import torch
class Scheduler:
    """ Parameter Scheduler Base Class

    A scheduler base class that can be used to schedule any optimizer parameter groups.

    Unlike the builtin PyTorch schedulers, this is intended to be consistently called
    * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
    * At the END of each optimizer update, after incrementing the update count, to calculate next update's value

    The schedulers built on this should try to remain as stateless as possible (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
    and -1 values for special behaviour. All epoch and update counts must be tracked in the training
    code and explicitly passed in to the schedulers on the corresponding step or step_update call.

    Optionally, multiplicative noise can be applied to the scheduled values: each value ``v``
    becomes ``v + v * noise`` on the steps selected by ``noise_range_t``.

    Based on ideas from:
     * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers

    Args:
        optimizer: Optimizer whose ``param_groups`` are scheduled.
        param_group_field: Param-group key to schedule (e.g. ``'lr'``).
        noise_range_t: Steps on which noise is applied. ``None`` disables noise. A number ``t0``
            applies noise for ``t >= t0``. A ``(start, end)`` pair applies it for
            ``start <= t < end``. A one-element list/tuple ``[t0]`` behaves like ``t0``.
        noise_type: ``'normal'`` for truncated Gaussian noise; any other value selects
            uniform noise.
        noise_pct: Bound on the magnitude of the relative noise.
        noise_std: Standard deviation of the Gaussian before truncation (``'normal'`` only).
        noise_seed: Base seed for noise generation; defaults to 42 when ``None``.
        initialize: If True, record ``initial_<field>`` in every param group from the group's
            current value, keeping any value that is already there. If False, that key must
            already be present.

    Raises:
        KeyError: If a param group lacks ``param_group_field`` (initialize=True) or
            ``initial_<field>`` (initialize=False).
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                # setdefault: an initial value already stored in the group is kept as-is
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        self.metric = None  # most recent metric passed to step() / step_update()
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        # Reset every group's scheduled field to its initial value.
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler's instance attributes, excluding the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore attributes from a dict produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Hook: return the value(s) for ``epoch``, or None to leave the groups unchanged."""
        return None

    def get_update_values(self, num_updates: int):
        """Hook: return the value(s) for ``num_updates``, or None to leave the groups unchanged."""
        return None

    def step(self, epoch: int, metric: Optional[float] = None) -> None:
        """Apply the per-epoch value(s), with optional noise, to the optimizer's param groups."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
        """Apply the per-update value(s), with optional noise, to the optimizer's param groups."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values) -> None:
        """Write ``values`` into each param group's scheduled field.

        A value that is not a list/tuple is broadcast to every group. Otherwise values are
        paired with groups by position; ``zip`` stops at the shorter of the two sequences.
        """
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _is_noise_active(self, t) -> bool:
        """Return True if noise should be applied at step ``t`` according to ``noise_range_t``."""
        noise_range = self.noise_range_t
        if noise_range is None:
            return False
        if isinstance(noise_range, (list, tuple)):
            if len(noise_range) == 1:
                # A single bound acts as a start threshold (indexing [1] would raise).
                return t >= noise_range[0]
            return noise_range[0] <= t < noise_range[1]
        return t >= noise_range

    def _calculate_noise(self, t) -> float:
        """Draw the relative noise for step ``t``, deterministically seeded with ``noise_seed + t``."""
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type == 'normal':
            if self.noise_pct <= 0:
                # No sample can satisfy |noise| < noise_pct; rejection sampling would never end.
                return 0.0
            while True:
                # Resample while outside the percent limit. The expected number of retries grows
                # with noise_std / noise_pct; with the defaults it is small.
                noise = torch.randn(1, generator=g).item() * self.noise_std
                if abs(noise) < self.noise_pct:
                    return noise
        # Uniform noise in [-noise_pct, noise_pct).
        return 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct

    def _add_noise(self, lrs, t):
        """Return ``lrs`` with multiplicative noise (``v + v * noise``) applied when active at ``t``.

        ``lrs`` may be a list/tuple (a new list is returned) or a single scalar (a scalar is
        returned). When noise is inactive, ``lrs`` is returned unchanged.
        """
        if not self._is_noise_active(t):
            return lrs
        noise = self._calculate_noise(t)
        if not isinstance(lrs, (list, tuple)):
            # Scalars are accepted by update_groups, so handle them here too.
            return lrs + lrs * noise
        return [v + v * noise for v in lrs]