# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import HOOKS, Hook


class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model in the
    training process. Every parameter has an EMA backup, which is updated
    by the formula below. EMAHook takes priority over EvalHook and
    CheckpointHook. Note that after training the original model parameters
    are actually saved in the EMA buffers, since the two sets are swapped
    at each epoch boundary.

    Args:
        momentum (float): The momentum used for updating the EMA parameters.
            The EMA parameters are updated with the formula:
            `ema_param = (1 - momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as the
            batchnorm running stats (running_mean, running_var), so that the
            EMA operation is not performed on them. Defaults to False.
        interval (int): Update the EMA parameters every `interval`
            iterations. Defaults to 1.
        resume_from (str, optional): The checkpoint path. Defaults to None.
        momentum_fun (func, optional): A function that changes the momentum
            during early iterations (i.e. warmup) to help early training.
            If None, `momentum` is used as a constant. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """Make resuming the model with its EMA parameters easier.

        Register each EMA parameter as a ``named_buffer`` on the model, so
        it is saved and loaded together with the checkpoint.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.skip_buffers:
            self.model_parameters = dict(model.named_parameters())
        else:
            self.model_parameters = model.state_dict()
        for name, value in self.model_parameters.items():
            # "." is not allowed in a module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers())
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
            self.momentum

    def after_train_iter(self, runner):
        """Update the EMA parameters every ``self.interval`` iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # exclude non-floating-point buffers such as num_batches_tracked
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """Load parameter values from the EMA backup into the model before
        the EvalHook runs."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Recover the model's parameters from the EMA backup after the last
        epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameters of the model with the parameters in the EMA
        buffers."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
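
# Illustration only, not part of the upstream API: a minimal sketch of the
# in-place update rule applied in ``BaseEMAHook.after_train_iter`` above.
# The name ``_ema_update_sketch`` is hypothetical and assumes plain
# ``torch.Tensor`` inputs.
def _ema_update_sketch(ema_param, cur_param, momentum=0.0002):
    """In-place EMA step: ema = (1 - momentum) * ema + momentum * cur.

    Example::

        >>> import torch
        >>> ema, cur = torch.zeros(3), torch.ones(3)
        >>> _ema_update_sketch(ema, cur)
        >>> torch.allclose(ema, torch.full((3,), 0.0002))
        True
    """
    ema_param.mul_(1 - momentum).add_(cur_param, alpha=momentum)
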
""" def __init__(self, total_iter=2000, **kwargs): super(ExpMomentumEMAHook, self).__init__(**kwargs) self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-( 1 + x) / total_iter) + self.momentum @HOOKS.register_module() class LinearMomentumEMAHook(BaseEMAHook): """EMAHook using linear momentum strategy. Args: warm_up (int): During first warm_up steps, we may use smaller decay to update ema parameters more slowly. Defaults to 100. """ def __init__(self, warm_up=100, **kwargs): super(LinearMomentumEMAHook, self).__init__(**kwargs) self.momentum_fun = lambda x: min(self.momentum**self.interval, (1 + x) / (warm_up + x))