#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import math

import torch
import torch.nn as nn

from yolov6.layers.common import *
class Detect(nn.Module):
    '''Efficient Decoupled Head.

    With hardware-aware design, the decoupled head is optimized with
    hybrid-channels methods.
    '''

    def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None):  # detection layer
        """
        Args:
            num_classes: number of object classes; each anchor predicts
                num_classes + 5 values (box 4, objectness 1, classes nc).
            anchors: anchor count per location, or an anchor-spec list/tuple
                (then ``len(anchors[0]) // 2`` anchors are used).
            num_layers: number of detection feature levels.
            inplace: decode boxes in place during inference (saves one cat).
            head_layers: flat indexable sequence of ``6 * num_layers`` modules,
                ordered [stem, cls_conv, reg_conv, cls_pred, reg_pred,
                obj_pred] per level.
        """
        super().__init__()
        if head_layers is None:
            # ValueError instead of assert: validation must survive python -O.
            raise ValueError('head_layers must be provided')
        self.nc = num_classes  # number of classes
        self.no = num_classes + 5  # number of outputs per anchor
        self.nl = num_layers  # number of detection layers
        if isinstance(anchors, (list, tuple)):
            self.na = len(anchors[0]) // 2
        else:
            self.na = anchors
        self.anchors = anchors
        self.grid = [torch.zeros(1)] * num_layers  # decode grids, rebuilt lazily per level
        self.prior_prob = 1e-2  # prior foreground probability for bias init
        self.inplace = inplace
        stride = [8, 16, 32]  # strides computed during build
        self.stride = torch.tensor(stride)

        # Init decoupled head branches
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.obj_preds = nn.ModuleList()
        self.stems = nn.ModuleList()

        # Unpack the flat head_layers sequence, 6 modules per detection level.
        for i in range(num_layers):
            idx = i * 6
            self.stems.append(head_layers[idx])
            self.cls_convs.append(head_layers[idx + 1])
            self.reg_convs.append(head_layers[idx + 2])
            self.cls_preds.append(head_layers[idx + 3])
            self.reg_preds.append(head_layers[idx + 4])
            self.obj_preds.append(head_layers[idx + 5])

    def initialize_biases(self):
        """Init cls/obj predictor biases so initial P(foreground) ~= prior_prob."""
        b_init = -math.log((1 - self.prior_prob) / self.prior_prob)
        for conv in self.cls_preds:
            b = conv.bias.view(self.na, -1)
            b.data.fill_(b_init)
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        for conv in self.obj_preds:
            b = conv.bias.view(self.na, -1)
            b.data.fill_(b_init)
            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def forward(self, x):
        """Run the decoupled head on the list of level feature maps.

        Args:
            x: list of self.nl feature tensors (modified in place).

        Returns:
            Training: list of per-level raw outputs, each
            (bs, na, ny, nx, no). Inference: a single decoded tensor
            (bs, sum(ny*nx*na), no) with xywh in input-image pixels.
        """
        z = []
        for i in range(self.nl):
            x[i] = self.stems[i](x[i])
            cls_x = x[i]
            reg_x = x[i]
            cls_feat = self.cls_convs[i](cls_x)
            cls_output = self.cls_preds[i](cls_feat)
            reg_feat = self.reg_convs[i](reg_x)
            reg_output = self.reg_preds[i](reg_feat)
            obj_output = self.obj_preds[i](reg_feat)
            if self.training:
                # Raw (undecoded) outputs for the loss: [box, obj, cls].
                x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
                bs, _, ny, nx = x[i].shape
                x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            else:
                y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
                bs, _, ny, nx = y.shape
                y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
                if self.grid[i].shape[2:4] != y.shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny)
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
                else:
                    xy = (y[..., 0:2] + self.grid[i]) * self.stride[i]  # xy
                    wh = torch.exp(y[..., 2:4]) * self.stride[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))
        return x if self.training else torch.cat(z, 1)

    def _make_grid(self, nx, ny):
        """Build the (1, na, ny, nx, 2) grid of (x, y) cell offsets.

        Broadcasted aranges replace the former two-arg ``torch.meshgrid``
        call, avoiding its indexing-default deprecation warning and creating
        the tensors directly on the stride device instead of moving them.
        The explicit expand over ``na`` also fixes the former
        ``view(1, self.na, ny, nx, 2)``, which raised for na > 1.
        """
        d = self.stride.device
        yv = torch.arange(ny, device=d).view(ny, 1).expand(ny, nx)
        xv = torch.arange(nx, device=d).view(1, nx).expand(ny, nx)
        grid = torch.stack((xv, yv), 2).view(1, 1, ny, nx, 2)
        return grid.expand(1, self.na, ny, nx, 2).float()
def build_effidehead_layer(channels_list, num_anchors, num_classes):
    """Build the flat stack of efficient decoupled head layers.

    Args:
        channels_list: per-stage channel counts; indices 6, 8 and 10 are the
            three head input widths (stride-8/16/32 levels respectively).
        num_anchors: anchors per location.
        num_classes: number of object classes.

    Returns:
        nn.Sequential of 6 modules per level in the order
        [stem, cls_conv, reg_conv, cls_pred, reg_pred, obj_pred] — the exact
        layout Detect.__init__ unpacks.
    """
    head_layers = []
    # The three per-level groups are identical except for the channel width,
    # so build them in a loop instead of eighteen copy-pasted definitions.
    # Construction order (and hence parameter-init RNG order) is unchanged.
    for ch in (channels_list[6], channels_list[8], channels_list[10]):
        head_layers.extend([
            # stem: 1x1 channel mixer
            Conv(in_channels=ch, out_channels=ch, kernel_size=1, stride=1),
            # cls_conv: 3x3 conv feeding the classification predictor
            Conv(in_channels=ch, out_channels=ch, kernel_size=3, stride=1),
            # reg_conv: 3x3 conv shared by the box and objectness predictors
            Conv(in_channels=ch, out_channels=ch, kernel_size=3, stride=1),
            # cls_pred
            nn.Conv2d(in_channels=ch, out_channels=num_classes * num_anchors, kernel_size=1),
            # reg_pred
            nn.Conv2d(in_channels=ch, out_channels=4 * num_anchors, kernel_size=1),
            # obj_pred
            nn.Conv2d(in_channels=ch, out_channels=1 * num_anchors, kernel_size=1),
        ])
    return nn.Sequential(*head_layers)