Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
import time | |
from contextlib import contextmanager | |
from copy import deepcopy | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from yolov6.utils.events import LOGGER | |
try: | |
import thop # for FLOPs computation | |
except ImportError: | |
thop = None | |
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the local master run the body first.

    Processes whose ``local_rank`` is neither -1 nor 0 block on a
    ``dist.barrier`` before entering the ``with`` body. The process with
    ``local_rank == 0`` runs the body immediately and then enters the barrier
    itself once the body has finished. With ``local_rank == -1`` no barrier
    is issued at all.

    Args:
        local_rank: Rank of the calling process on this node; -1 skips all
            synchronisation.

    Yields:
        None. Control is handed to the ``with`` body.
    """
    if local_rank not in [-1, 0]:
        # Non-master ranks wait here before running the body.
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        # The master reaches the barrier only after its body has completed.
        dist.barrier(device_ids=[0])
def time_sync():
    """Return the current wall-clock time, synchronising CUDA first.

    When CUDA is available, ``torch.cuda.synchronize()`` is called before the
    timestamp is taken so that pending GPU work is finished.

    Returns:
        float: The value of ``time.time()`` taken after the optional sync.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.synchronize()
    timestamp = time.time()
    return timestamp
def initialize_weights(model):
    """Adjust BatchNorm and activation settings on every sub-module in place.

    Only modules whose type is *exactly* one of the listed classes are
    touched (subclasses are skipped):

    * ``nn.BatchNorm2d``: ``eps`` is set to 1e-3 and ``momentum`` to 0.03.
    * ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``,
      ``nn.SiLU``: ``inplace`` is set to ``True``.

    ``nn.Conv2d`` layers are deliberately left as they are.

    Args:
        model: Module whose ``modules()`` are visited.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
    for module in model.modules():
        kind = type(module)
        # nn.Conv2d keeps its existing initialisation, so it needs no branch.
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the preceding Conv2d layer.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/ for the maths.
    The BatchNorm running statistics are used, so the fused layer matches
    ``bn(conv(x))`` when ``bn`` is in eval mode.

    Args:
        conv: The ``nn.Conv2d`` whose output feeds ``bn``.
        bn: The ``nn.BatchNorm2d`` to fold in (``weight``, ``bias``,
            ``running_mean``, ``running_var`` and ``eps`` are read).

    Returns:
        A new ``nn.Conv2d`` with a bias term, ``requires_grad`` disabled, placed
        on the same device as ``conv.weight``.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # Carry over dilation/padding mode so the fused layer has the same
            # geometry as the original convolution.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Do the arithmetic outside autograd: copying a grad-tracking tensor into
    # the frozen parameters would otherwise attach them to the graph.
    with torch.no_grad():
        # Per-output-channel BatchNorm scale: gamma / sqrt(var + eps).
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

        # Fused filters: every output channel is multiplied by its scale.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        # Fused bias: scale * conv_bias + (beta - scale * running_mean).
        if conv.bias is None:
            conv_bias = torch.zeros(
                conv.out_channels,
                device=conv.weight.device,
                dtype=conv.weight.dtype,
            )
        else:
            conv_bias = conv.bias
        fusedconv.bias.copy_(conv_bias * scale + bn.bias - bn.running_mean * scale)

    return fusedconv
def fuse_model(model):
    """Fuse Conv + BatchNorm pairs inside every ``Conv`` block of ``model``.

    Each sub-module whose type is exactly ``yolov6.layers.common.Conv`` and
    which still has a ``bn`` attribute is rewritten in place: its ``conv`` is
    replaced by the fused convolution, ``bn`` is removed, and ``forward`` is
    rebound to ``forward_fuse``.

    Args:
        model: Module to fuse; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    from yolov6.layers.common import Conv

    for layer in model.modules():
        # Exact type match only; blocks already fused have no ``bn`` left.
        if type(layer) is not Conv or not hasattr(layer, "bn"):
            continue
        fused = fuse_conv_and_bn(layer.conv, layer.bn)
        layer.conv = fused  # update conv
        delattr(layer, "bn")  # remove batchnorm
        layer.forward = layer.forward_fuse  # update forward
    return model
def get_model_info(model, img_size=640):
    """Get model Params and GFlops.

    Code base on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py

    A deep copy of ``model`` is profiled with ``thop`` on a zero-filled
    ``1 x 3 x 32 x 32`` input, and the measured cost is scaled up to the
    requested image size.

    Args:
        model: Module to profile; it must have at least one parameter, whose
            device is used for the dummy input.
        img_size: Either a single int (square image) or a list/tuple whose
            first two entries are the image sides.

    Returns:
        str: ``"Params: <x>M, Gflops: <y>"`` with both values to 2 decimals.

    Raises:
        ImportError: If ``thop`` is not installed.
    """
    from thop import profile

    stride = 32
    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
    # Profile a copy so that profiling does not touch the caller's model.
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6
    flops /= 1e9
    # Accept tuples as well as lists; a tuple used to be wrapped in a list and
    # then failed in the multiplication below.
    if not isinstance(img_size, (list, tuple)):
        img_size = [img_size, img_size]
    # Scale from the stride x stride probe to the full image area. The factor
    # 2 presumably converts thop's multiply-accumulate count to FLOPs.
    flops *= img_size[0] * img_size[1] / stride / stride * 2  # Gflops
    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
    return info