Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from yolov6.layers.common import RepBlock, SimConv, Transpose | |
class RepPANNeck(nn.Module): | |
"""RepPANNeck Module | |
EfficientRep is the default backbone of this model. | |
RepPANNeck has the balance of feature fusion ability and hardware efficiency. | |
""" | |
def __init__( | |
self, | |
channels_list=None, | |
num_repeats=None | |
): | |
super().__init__() | |
assert channels_list is not None | |
assert num_repeats is not None | |
self.Rep_p4 = RepBlock( | |
in_channels=channels_list[3] + channels_list[5], | |
out_channels=channels_list[5], | |
n=num_repeats[5], | |
) | |
self.Rep_p3 = RepBlock( | |
in_channels=channels_list[2] + channels_list[6], | |
out_channels=channels_list[6], | |
n=num_repeats[6] | |
) | |
self.Rep_n3 = RepBlock( | |
in_channels=channels_list[6] + channels_list[7], | |
out_channels=channels_list[8], | |
n=num_repeats[7], | |
) | |
self.Rep_n4 = RepBlock( | |
in_channels=channels_list[5] + channels_list[9], | |
out_channels=channels_list[10], | |
n=num_repeats[8] | |
) | |
self.reduce_layer0 = SimConv( | |
in_channels=channels_list[4], | |
out_channels=channels_list[5], | |
kernel_size=1, | |
stride=1 | |
) | |
self.upsample0 = Transpose( | |
in_channels=channels_list[5], | |
out_channels=channels_list[5], | |
) | |
self.reduce_layer1 = SimConv( | |
in_channels=channels_list[5], | |
out_channels=channels_list[6], | |
kernel_size=1, | |
stride=1 | |
) | |
self.upsample1 = Transpose( | |
in_channels=channels_list[6], | |
out_channels=channels_list[6] | |
) | |
self.downsample2 = SimConv( | |
in_channels=channels_list[6], | |
out_channels=channels_list[7], | |
kernel_size=3, | |
stride=2 | |
) | |
self.downsample1 = SimConv( | |
in_channels=channels_list[8], | |
out_channels=channels_list[9], | |
kernel_size=3, | |
stride=2 | |
) | |
def forward(self, input): | |
(x2, x1, x0) = input | |
fpn_out0 = self.reduce_layer0(x0) | |
upsample_feat0 = self.upsample0(fpn_out0) | |
f_concat_layer0 = torch.cat([upsample_feat0, x1], 1) | |
f_out0 = self.Rep_p4(f_concat_layer0) | |
fpn_out1 = self.reduce_layer1(f_out0) | |
upsample_feat1 = self.upsample1(fpn_out1) | |
f_concat_layer1 = torch.cat([upsample_feat1, x2], 1) | |
pan_out2 = self.Rep_p3(f_concat_layer1) | |
down_feat1 = self.downsample2(pan_out2) | |
p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1) | |
pan_out1 = self.Rep_n3(p_concat_layer1) | |
down_feat0 = self.downsample1(pan_out1) | |
p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1) | |
pan_out0 = self.Rep_n4(p_concat_layer2) | |
outputs = [pan_out2, pan_out1, pan_out0] | |
return outputs | |