#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# The code is based on
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
import math
from copy import deepcopy

import torch
import torch.nn as nn


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramps up from 0 towards `decay` as updates accumulate.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)
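
    # Illustrative decay-ramp values for the default decay=0.9999 (rounded):
    #   updates=100   -> ~0.049
    #   updates=1000  -> ~0.393
    #   updates=2000  -> ~0.632
    #   updates=10000 -> ~0.993
    # The ramp keeps the EMA tracking the raw weights closely at the start of
    # training and only approaches the nominal decay after many updates.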

    def update(self, model):
        with torch.no_grad():
            self.updates += 1
            decay = self.decay(self.updates)

            state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, item in self.ema.state_dict().items():
                if item.dtype.is_floating_point:
                    item *= decay
                    item += (1 - decay) * state_dict[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Copy non-tensor attributes (e.g. metadata set on the model) onto the EMA copy.
        copy_attr(self.ema, model, include, exclude)
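
# Typical integration (a sketch, not part of the original utility; `model`,
# `optimizer`, `train_loader` and `compute_loss` are assumed to exist in the
# training script):
#
#     ema = ModelEMA(model)
#     for imgs, targets in train_loader:
#         loss = compute_loss(model(imgs), targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         ema.update(model)       # refresh the averaged weights after each step
#     ema.update_attr(model)      # sync non-tensor attributes before saving/eval
#     # evaluation and checkpointing are usually done with ema.ema, not model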


def copy_attr(a, b, include=(), exclude=()):
    """Copy attributes from instance `b` onto instance `a`, honoring the include/exclude filters."""
    for k, item in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, item)
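
# Hedged example: in YOLO-style trainers, copy_attr (via update_attr) is often
# used to carry model metadata onto the EMA copy, e.g.
#     copy_attr(ema.ema, model, include=('nc', 'names', 'stride'))
# The attribute names above are illustrative and depend on the model definition.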


def is_parallel(model):
    # Return True if model's type is DP or DDP, else False.
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
    return model.module if is_parallel(model) else model
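

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; not part of the original module).
    # Builds a tiny CPU model, runs a few optimizer steps, and reports how far
    # the averaged weights differ from the live weights.
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    ema = ModelEMA(model, decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(10):
        x = torch.randn(16, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # EMA tracks the live weights after each optimizer step

    diff = (model[0].weight.detach() - ema.ema[0].weight).abs().max().item()
    print(f"max |raw - ema| weight difference after 10 steps: {diff:.6f}")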