chendl's picture
Add application file
0b7b08a
raw
history blame
5.61 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import contextlib
from copy import deepcopy
from typing import Sequence
import torch
import torch.nn as nn
# Public API of this module: the names exported by a star-import.
__all__ = [
"fuse_conv_and_bn",
"fuse_model",
"get_model_info",
"replace_module",
"freeze_module",
"adjust_status",
]
def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
    """Summarize a model's parameter count and estimated compute cost.

    A deep copy of the model is profiled with ``thop`` on a small all-zero
    probe image, and the measured operation count is rescaled from the probe
    area to the requested test size.

    Args:
        model (nn.Module): model to profile. The probe image is created on the
            device of the model's first parameter.
        tsize (Sequence[int]): test size; its first two entries are multiplied
            to obtain the target area (presumably height and width).

    Returns:
        str: text of the form "Params: <x>M, Gflops: <y>".
    """
    from thop import profile

    probe_side = 64
    device = next(model.parameters()).device
    probe = torch.zeros((1, 3, probe_side, probe_side), device=device)
    raw_flops, raw_params = profile(deepcopy(model), inputs=(probe,), verbose=False)

    # Scale from the probe area to the target area. NOTE(review): the trailing
    # factor 2 presumably converts thop's multiply-accumulate count to FLOPs.
    area_scale = tsize[0] * tsize[1] / probe_side / probe_side * 2
    params_m = raw_params / 1e6
    gflops = raw_flops / 1e9 * area_scale
    return "Params: {:.2f}M, Gflops: {:.2f}".format(params_m, gflops)
def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """
    Fuse convolution and batchnorm layers.
    check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    The batchnorm is folded in using its running statistics, so the result
    matches ``bn(conv(x))`` only when ``bn`` is used in eval mode.

    Args:
        conv (nn.Conv2d): convolution to fuse.
        bn (nn.BatchNorm2d): batchnorm to fuse. ``affine=False`` is supported
            (treated as scale 1 and shift 0).

    Returns:
        nn.Conv2d: fused convolution behaves the same as the input conv and bn.
            Its parameters have ``requires_grad=False`` and live on the same
            device and in the same dtype as ``conv.weight``.
    """
    weight = conv.weight
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # dilation and padding_mode must be carried over, otherwise the
            # fused layer computes a different convolution than the original
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=weight.device, dtype=weight.dtype)
    )

    # no_grad keeps autograd history out of the frozen fused parameters
    with torch.no_grad():
        # per-output-channel multiplier: gamma / sqrt(running_var + eps)
        scale = 1.0 / torch.sqrt(bn.running_var + bn.eps)
        if bn.weight is not None:
            scale = bn.weight * scale

        # prepare filters: scale every output channel's kernel
        fusedconv.weight.copy_(weight * scale.reshape(-1, 1, 1, 1))

        # prepare spatial bias: (b_conv - running_mean) * scale + beta
        if conv.bias is None:
            centered = -bn.running_mean
        else:
            centered = conv.bias - bn.running_mean
        fused_bias = centered * scale
        if bn.bias is not None:
            fused_bias = fused_bias + bn.bias
        fusedconv.bias.copy_(fused_bias)

    return fusedconv
def fuse_model(model: nn.Module) -> nn.Module:
    """Fuse conv and bn layers of every BaseConv in the model, in place.

    Args:
        model (nn.Module): model to fuse.

    Returns:
        nn.Module: the same model object, with its BaseConv blocks fused.
    """
    from yolox.models.network_blocks import BaseConv

    for layer in model.modules():
        # Exact type match on purpose: subclasses of BaseConv are skipped, as
        # are blocks that no longer carry a "bn" attribute.
        if type(layer) is not BaseConv or not hasattr(layer, "bn"):
            continue
        layer.conv = fuse_conv_and_bn(layer.conv, layer.bn)  # conv now includes bn
        del layer.bn  # batchnorm is no longer needed
        layer.forward = layer.fuseforward  # switch to the forward pass without bn
    return model
def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
    """
    Replace given type in module to a new type. mostly used in deploy.

    Args:
        module (nn.Module): model to apply replace operation.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type): module type to put in its place.
        replace_func (function): python function to describe replace logic. It is
            called as ``replace_func(replaced_module_type, new_module_type)`` and
            must return the new module. Default value None, which builds the
            replacement with ``new_module_type()``.

    Returns:
        model (nn.Module): module that already been replaced. If ``module`` itself
            is of the replaced type a new module is returned; otherwise ``module``
            is modified in place and returned.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    if isinstance(module, replaced_module_type):
        return replace_func(replaced_module_type, new_module_type)

    # recursively replace; snapshot the children because they are reassigned below
    for name, child in list(module.named_children()):
        # forward replace_func so nested modules use the caller's logic too
        new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
        if new_child is not child:  # child is already replaced
            module.add_module(name, new_child)
    return module
def freeze_module(module: nn.Module, name=None) -> nn.Module:
    """Freeze a module in place.

    Selected parameters get ``requires_grad = False`` and selected sub-modules
    are switched to eval mode.

    Args:
        module (nn.Module): module to freeze.
        name (str, optional): name to freeze. If not given, freeze the whole module.
            Otherwise a parameter or sub-module is selected when ``name`` occurs
            in its dotted name. Defaults to None.

    Returns:
        nn.Module: the same ``module`` object.

    Examples:
        freeze the backbone of model
        >>> freeze_module(model.backbone)

        or freeze the backbone of model by name
        >>> freeze_module(model, name="backbone")
    """

    def is_selected(dotted_name):
        return name is None or name in dotted_name

    for dotted_name, param in module.named_parameters():
        if is_selected(dotted_name):
            param.requires_grad = False

    # Put layers such as BN and dropout into eval mode as well. eval() already
    # recurses into children, so calling it on every match is redundant but
    # harmless.
    for dotted_name, child in module.named_modules():
        if is_selected(dotted_name):
            child.eval()

    return module
@contextlib.contextmanager
def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
    """Adjust module to training/eval mode temporarily.

    On entry, the ``training`` flag of ``module`` and of every sub-module is
    saved and then set to ``training``. On exit the saved flags are restored,
    also when the body of the ``with`` block raises.

    Args:
        module (nn.Module): module to adjust status.
        training (bool): training mode to set. True for train mode, False for eval mode.

    Yields:
        nn.Module: the same ``module`` object.

    Examples:
        >>> with adjust_status(model, training=False):
        ...     model(data)
    """
    # save prev status of every (sub-)module before touching any of them
    status = {m: m.training for m in module.modules()}
    for m in status:
        m.training = training
    try:
        yield module
    finally:
        # Restore from the snapshot rather than re-walking the module tree, so
        # sub-modules added or removed inside the block cannot break recovery.
        for m, was_training in status.items():
            m.training = was_training