# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import HOOKS, Hook


class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model in the
    training process. Each parameter has an ema backup, which is updated by
    the formula below. EMAHook takes priority over EvalHook and
    CheckpointHook. Note that the original model parameters are actually
    saved in the ema fields after training.

    Args:
        momentum (float): The momentum used for updating the ema parameters.
            The ema parameters are updated with the formula:
            `ema_param = (1 - momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as the
            batchnorm running stats (``running_mean``, ``running_var``), so
            that the ema operation is not performed on them.
            Defaults to False.
        interval (int): Update the ema parameters every ``interval``
            iterations. Defaults to 1.
        resume_from (str, optional): The checkpoint path to resume from.
            Defaults to None.
        momentum_fun (func, optional): The function to change momentum
            during early iterations (also warmup) to help early training.
            If None, `momentum` is used as a constant. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """Make it easier to resume the model together with its ema
        parameters.

        Register the ema parameters as ``named_buffer`` of the model, e.g.
        ``backbone.conv1.weight`` is backed up as
        ``ema_backbone_conv1_weight``.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.skip_buffers:
            self.model_parameters = dict(model.named_parameters())
        else:
            self.model_parameters = model.state_dict()
        for name, value in self.model_parameters.items():
            # "." is not allowed in a module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers())
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
            self.momentum

    def after_train_iter(self, runner):
        """Update the ema parameters every ``self.interval`` iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # skip non-float tensors such as the integer
            # ``num_batches_tracked`` buffer of batchnorm layers
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """Load the ema parameters into the model before the EvalHook runs,
        so that evaluation uses the ema weights."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Restore the model's own parameters from the ema backup after the
        last epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameters of the model with the parameters in the ema
        buffers."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
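

# A minimal numeric sketch of the update rule implemented in
# ``after_train_iter`` (illustration only, not part of the hook API): with
# momentum ``m``, each update computes ``ema = (1 - m) * ema + m * param``
# in place via ``mul_``/``add_``.
#
#   >>> import torch
#   >>> m = 0.0002
#   >>> ema, param = torch.ones(1), torch.zeros(1)
#   >>> ema.mul_(1 - m).add_(param, alpha=m)
#   tensor([0.9998])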


@HOOKS.register_module()
class ExpMomentumEMAHook(BaseEMAHook):
    """EMAHook using an exponential momentum strategy.

    Args:
        total_iter (int): The total number of iterations of EMA momentum.
            Defaults to 2000.
    """

    def __init__(self, total_iter=2000, **kwargs):
        super(ExpMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-(
            1 + x) / total_iter) + self.momentum
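
# Sketch of this schedule with the defaults (momentum=0.0002,
# total_iter=2000): the effective momentum starts near 1 (about 0.9995 at
# iteration 0), so the ema closely tracks the raw weights early in training,
# then decays exponentially towards the constant 0.0002 (about 0.368 at
# iteration 2000).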


@HOOKS.register_module()
class LinearMomentumEMAHook(BaseEMAHook):
    """EMAHook using a linear momentum strategy.

    Args:
        warm_up (int): During the first warm_up steps, a smaller momentum
            may be used to update the ema parameters more slowly.
            Defaults to 100.
    """

    def __init__(self, warm_up=100, **kwargs):
        super(LinearMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: min(self.momentum**self.interval,
                                          (1 + x) / (warm_up + x))
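

# A minimal usage sketch, assuming the standard MMCV/MMDetection config
# convention for custom hooks (the values are illustrative; the priority is
# set above the default NORMAL level of 50 so that the hook runs before
# EvalHook and CheckpointHook, as the class docstring requires):
#
#   custom_hooks = [
#       dict(
#           type='ExpMomentumEMAHook',
#           total_iter=2000,
#           momentum=0.0002,
#           skip_buffers=True,
#           priority=49)
#   ]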