# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmcv.parallel import is_module_wrapper
from mmcv.runner.hooks import HOOKS, Hook


class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model during
    training. Each parameter has an EMA backup, which is updated by the
    formula below. The EMA hook takes priority over EvalHook and
    CheckpointHook. Note that the original model parameters are actually
    saved in the EMA fields after training.

    Args:
        momentum (float): The momentum used for updating the EMA parameters.
            The EMA parameters are updated with the formula:
            `ema_param = (1 - momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as the
            batchnorm running stats (running_mean, running_var), so that the
            EMA operation is not performed on them. Defaults to False.
        interval (int): Update the EMA parameters every ``interval``
            iterations. Defaults to 1.
        resume_from (str, optional): The checkpoint path to resume from.
            Defaults to None.
        momentum_fun (func, optional): The function used to change the
            momentum during early iterations (also warmup) to help early
            training. If None, ``momentum`` is used as a constant.
            Defaults to None.
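
    Example:
        A minimal sketch of enabling one of the registered EMA hooks through
        a runner config; the ``custom_hooks`` field and ``priority=49``
        (a higher priority than the default ``NORMAL`` level of EvalHook and
        CheckpointHook) follow the MMCV runner convention::

            custom_hooks = [
                dict(type='ExpMomentumEMAHook', momentum=0.0002, priority=49)
            ]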
"""

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """Register the EMA parameters as ``named_buffer``s on the model so
        that the model can be resumed together with its EMA parameters more
        conveniently.
        """
model = runner.model
if is_module_wrapper(model):
model = model.module
self.param_ema_buffer = {}
if self.skip_buffers:
self.model_parameters = dict(model.named_parameters())
else:
self.model_parameters = model.state_dict()
for name, value in self.model_parameters.items():
# "." is not allowed in module's buffer name
buffer_name = f"ema_{name.replace('.', '_')}"
self.param_ema_buffer[name] = buffer_name
model.register_buffer(buffer_name, value.data.clone())
self.model_buffers = dict(model.named_buffers())
if self.checkpoint is not None:
runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
            self.momentum

    def after_train_iter(self, runner):
        """Update the EMA parameters every ``self.interval`` iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # exclude non-floating-point entries such as the
            # ``num_batches_tracked`` buffers of batchnorm layers
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                # ema_param = (1 - momentum) * ema_param + momentum * param
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """Load the parameter values from the EMA backup into the model
        before the EvalHook runs."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Restore the model's parameters from the EMA backup after the last
        epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the model's parameters with the parameters stored in the EMA
        buffers."""
for name, value in self.model_parameters.items():
temp = value.data.clone()
ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
value.data.copy_(ema_buffer.data)
ema_buffer.data.copy_(temp)


@HOOKS.register_module()
class ExpMomentumEMAHook(BaseEMAHook):
    """EMAHook using an exponential momentum strategy.

    Args:
        total_iter (int): The total number of iterations over which the EMA
            momentum decays. Defaults to 2000.
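
    Example:
        A rough numerical sketch of the default schedule; the lambda below
        mirrors the ``momentum_fun`` defined in ``__init__``, with
        ``momentum=0.0002`` and ``total_iter=2000``:

        >>> import math
        >>> fun = lambda x: (1 - 0.0002) * math.exp(-(1 + x) / 2000) + 0.0002
        >>> round(fun(0), 4)      # start: the EMA tracks the live weights
        0.9995
        >>> round(fun(1999), 3)   # after one ``total_iter``
        0.368
        >>> round(fun(20000), 4)  # late training: decays towards ``momentum``
        0.0002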
"""

    def __init__(self, total_iter=2000, **kwargs):
        super(ExpMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-(
            1 + x) / total_iter) + self.momentum


@HOOKS.register_module()
class LinearMomentumEMAHook(BaseEMAHook):
    """EMAHook using a linear momentum strategy.

    Args:
        warm_up (int): During the first ``warm_up`` steps, a smaller momentum
            may be used to update the EMA parameters more slowly.
            Defaults to 100.
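
    Example:
        A rough numerical sketch of the warm-up; the lambda below mirrors the
        ``momentum_fun`` defined in ``__init__``, with the illustrative
        (non-default) values ``momentum=0.1`` and ``interval=1``:

        >>> fun = lambda x: min(0.1**1, (1 + x) / (100 + x))
        >>> fun(0)   # warm-up throttles the momentum at first
        0.01
        >>> fun(10)  # the constant cap ``momentum**interval`` takes over
        0.1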
"""

    def __init__(self, warm_up=100, **kwargs):
        super(LinearMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: min(self.momentum**self.interval,
                                          (1 + x) / (warm_up + x))
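

# A standalone sketch (an illustration, not part of the hook API) of the
# effect of a single EMA update with the default momentum, using the same
# in-place ``mul_``/``add_`` pattern as ``after_train_iter``:
#
#     >>> import torch
#     >>> momentum = 0.0002
#     >>> ema, param = torch.ones(2), torch.zeros(2)
#     >>> ema.mul_(1 - momentum).add_(param, alpha=momentum)
#     tensor([0.9998, 0.9998])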