Spaces: Running on Zero
import torch | |
from collections import OrderedDict | |
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each parameter of ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``, where ``param`` is the
    parameter of ``model`` with the same name. Only parameters (as returned
    by ``named_parameters()``) are averaged; buffers are not touched.

    Args:
        ema_model: module holding the exponential-moving-average weights;
            its parameters are modified in place.
        model: module whose current parameters are blended into the EMA.
        decay: fraction of the previous EMA value that is kept
            (default 0.9999).

    Returns:
        None.

    Raises:
        KeyError: if ``model`` has a parameter name that ``ema_model`` lacks.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())
    # The EMA step must not be recorded by autograd. Running it under
    # no_grad also avoids autograd's RuntimeError for in-place ops on leaf
    # tensors that still have requires_grad=True.
    with torch.no_grad():
        for name, param in model_params.items():
            # TODO: Consider applying only to params that require_grad to avoid
            # small numerical changes of pos_embed
            ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model: object exposing ``parameters()``, such as a
            ``torch.nn.Module``.
        flag: value assigned to each parameter's ``requires_grad``
            attribute (default ``True``).

    Returns:
        None; the parameters are modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad = flag