Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii Inc. All rights reserved. | |
import contextlib | |
from copy import deepcopy | |
from typing import Sequence | |
import torch | |
import torch.nn as nn | |
# Public names of this module: the list below is what `from <module> import *` exports.
__all__ = [ | |
"fuse_conv_and_bn", | |
"fuse_model", | |
"get_model_info", | |
"replace_module", | |
"freeze_module", | |
"adjust_status", | |
] | |
def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
    """Summarise the parameter count and estimated compute cost of ``model``.

    A deep copy of the model is profiled with ``thop`` on a single all-zero
    3-channel 64x64 probe image, created on the device of the model's first
    parameter. The measured operation count is then rescaled by the area
    ratio between ``tsize`` and the probe, and multiplied by two
    (presumably converting multiply-accumulates to FLOPs -- confirm against
    the ``thop`` documentation).

    Args:
        model (nn.Module): network to profile; it must own at least one
            parameter, because the probe device is read from the first one.
        tsize (Sequence[int]): target input size; only ``tsize[0]`` and
            ``tsize[1]`` are read.

    Returns:
        str: text of the form ``"Params: X.XXM, Gflops: Y.YY"``.
    """
    from thop import profile

    probe_side = 64
    probe_device = next(model.parameters()).device
    probe = torch.zeros((1, 3, probe_side, probe_side), device=probe_device)

    # Profile a copy so that any hooks/buffers thop attaches never touch `model`.
    raw_flops, raw_params = profile(deepcopy(model), inputs=(probe,), verbose=False)

    params_m = raw_params / 1e6  # millions of parameters
    gflops = raw_flops / 1e9
    gflops *= tsize[0] * tsize[1] / probe_side / probe_side * 2  # Gflops
    return "Params: {:.2f}M, Gflops: {:.2f}".format(params_m, gflops)
def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fuse a convolution and the batchnorm that follows it into one conv.

    The batchnorm is folded in using its *running* statistics, so the result
    matches ``bn(conv(x))`` only when the batchnorm is evaluated in eval mode.
    check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    Args:
        conv (nn.Conv2d): convolution to fuse.
        bn (nn.BatchNorm2d): batchnorm to fuse; its ``weight``, ``bias``,
            ``running_mean`` and ``running_var`` are all read.

    Returns:
        nn.Conv2d: a new convolution (parameters have ``requires_grad=False``)
        that behaves the same as the input conv followed by bn.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation and padding_mode must be carried over, otherwise the
            # fused layer computes a different function than the original conv.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # The fused parameters are constants: keep autograd out of the copies so
    # they stay plain leaf tensors even when called outside torch.no_grad().
    with torch.no_grad():
        # per-output-channel scale applied by the batchnorm
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

        # prepare filters: scale every output filter by its channel factor
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # prepare spatial bias: bn(conv_bias) = scale * (b - mean) + beta
        if conv.bias is None:
            b_conv = torch.zeros_like(bn.running_mean)
        else:
            b_conv = conv.bias
        fusedconv.bias.copy_(bn.bias + scale * (b_conv - bn.running_mean))

    return fusedconv
def fuse_model(model: nn.Module) -> nn.Module:
    """Fold batchnorm into the convolution of every ``BaseConv`` in ``model``.

    The model is modified in place. Only modules whose type is exactly
    ``BaseConv`` (subclasses are skipped) and that still carry a ``bn``
    attribute are touched, so modules that were already fused are left alone.

    Args:
        model (nn.Module): model to fuse.

    Returns:
        nn.Module: the same ``model`` object, after fusion.
    """
    from yolox.models.network_blocks import BaseConv

    for layer in model.modules():
        is_fusable = type(layer) is BaseConv and hasattr(layer, "bn")
        if not is_fusable:
            continue

        fused = fuse_conv_and_bn(layer.conv, layer.bn)
        layer.conv = fused  # conv now contains the batchnorm
        del layer.bn  # the standalone batchnorm is no longer needed
        # Route calls through the module's `fuseforward` method
        # (presumably the bn-free forward path -- confirm in BaseConv).
        layer.forward = layer.fuseforward

    return model
def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
    """
    Replace given type in module to a new type. mostly used in deploy.

    If ``module`` itself is an instance of ``replaced_module_type`` a new
    module is returned; otherwise its children are searched recursively and
    matching children are swapped in place.

    Args:
        module (nn.Module): model to apply replace operation.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to create instead.
        replace_func (function): python function to describe replace logic.
            It is called as ``replace_func(replaced_module_type, new_module_type)``
            and must return the new module. Default value None, which
            builds ``new_module_type()`` with no arguments.

    Returns:
        model (nn.Module): module that already been replaced.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)

    # recursively replace; snapshot the children because matching ones are
    # re-registered on `module` while we walk them
    for name, child in list(module.named_children()):
        new_child = replace_module(
            child,
            replaced_module_type,
            new_module_type,
            # forward the caller's factory so nested matches use it too
            replace_func=replace_func,
        )
        if new_child is not child:  # child is already replaced
            module.add_module(name, new_child)

    return module
def freeze_module(module: nn.Module, name=None) -> nn.Module:
    """Freeze ``module`` (or a named part of it) in place.

    Selected parameters get ``requires_grad = False`` and selected
    sub-modules are switched to eval mode.

    Args:
        module (nn.Module): module to freeze.
        name (str, optional): when given, only parameters / sub-modules whose
            qualified name contains ``name`` as a substring are frozen.
            When omitted, the whole module is frozen. Defaults to None.

    Returns:
        nn.Module: the same ``module`` object.

    Examples:
        freeze the backbone of model
        >>> freeze_module(model.backbone)

        or freeze the backbone of model by name
        >>> freeze_module(model, name="backbone")
    """
    freeze_everything = name is None

    # stop gradients for every selected parameter
    for qualified_name, param in module.named_parameters():
        if not freeze_everything and name not in qualified_name:
            continue
        param.requires_grad = False

    # put selected sub-modules in eval mode so layers such as BN and dropout
    # stop behaving as in training; eval() already recurses into children,
    # so calling it on each match is redundant but harmless
    for qualified_name, child in module.named_modules():
        if not freeze_everything and name not in qualified_name:
            continue
        child.eval()

    return module
@contextlib.contextmanager
def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
    """Adjust module to training/eval mode temporarily.

    Intended to be used as a context manager: on entry the ``training`` flag
    of ``module`` and all of its sub-modules is set to ``training``; on exit
    (including when the body raises) every flag is restored to the value it
    had before.

    Args:
        module (nn.Module): module to adjust status.
        training (bool): training mode to set. True for train mode, False for eval mode.

    Yields:
        nn.Module: the same ``module`` object, with the temporary mode applied.

    Examples:
        >>> with adjust_status(model, training=False):
        ...     model(data)
    """
    # previous `training` flag of every (sub-)module, keyed by the module itself
    status = {}

    def backup_status(module):
        for m in module.modules():
            # save prev status to dict
            status[m] = m.training
            m.training = training

    def recover_status():
        # Restore from the saved snapshot rather than re-walking the module
        # tree, so sub-modules added or removed inside the block cannot
        # cause a lookup failure.
        for m, was_training in status.items():
            m.training = was_training
        status.clear()

    backup_status(module)
    try:
        yield module
    finally:
        # always undo the temporary mode, even if the body raised
        recover_status()