import torch
from torch import nn

from yolov6.layers.common import RepBlock, SimConv, Transpose


class RepPANNeck(nn.Module):
"""RepPANNeck Module
EfficientRep is the default backbone of this model.
RepPANNeck has the balance of feature fusion ability and hardware efficiency.
"""
    def __init__(
        self,
        channels_list=None,
        num_repeats=None
    ):
        super().__init__()

        assert channels_list is not None
        assert num_repeats is not None
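
        # Index convention (an assumption, mirroring how the YOLOv6 config
        # files assemble these lists): channels_list concatenates the
        # backbone's output channels (indices 0-4) with the neck's
        # (indices 5-10); num_repeats does the same (backbone 0-4, neck 5-8).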
        # Fusion blocks: top-down (P4, P3) and bottom-up (N3, N4) paths.
        self.Rep_p4 = RepBlock(
            in_channels=channels_list[3] + channels_list[5],
            out_channels=channels_list[5],
            n=num_repeats[5],
        )

        self.Rep_p3 = RepBlock(
            in_channels=channels_list[2] + channels_list[6],
            out_channels=channels_list[6],
            n=num_repeats[6],
        )

        self.Rep_n3 = RepBlock(
            in_channels=channels_list[6] + channels_list[7],
            out_channels=channels_list[8],
            n=num_repeats[7],
        )

        self.Rep_n4 = RepBlock(
            in_channels=channels_list[5] + channels_list[9],
            out_channels=channels_list[10],
            n=num_repeats[8],
        )
        # 1x1 convs that reduce channels before each upsample.
        self.reduce_layer0 = SimConv(
            in_channels=channels_list[4],
            out_channels=channels_list[5],
            kernel_size=1,
            stride=1,
        )

        self.upsample0 = Transpose(
            in_channels=channels_list[5],
            out_channels=channels_list[5],
        )

        self.reduce_layer1 = SimConv(
            in_channels=channels_list[5],
            out_channels=channels_list[6],
            kernel_size=1,
            stride=1,
        )

        self.upsample1 = Transpose(
            in_channels=channels_list[6],
            out_channels=channels_list[6],
        )

        # Strided 3x3 convs that downsample on the bottom-up path.
        self.downsample2 = SimConv(
            in_channels=channels_list[6],
            out_channels=channels_list[7],
            kernel_size=3,
            stride=2,
        )

        self.downsample1 = SimConv(
            in_channels=channels_list[8],
            out_channels=channels_list[9],
            kernel_size=3,
            stride=2,
        )
    def forward(self, input):
        # P3, P4, P5 backbone features; x2 is the highest-resolution map.
        (x2, x1, x0) = input

        # Top-down (FPN) path: reduce channels, upsample, concat, fuse.
        fpn_out0 = self.reduce_layer0(x0)
        upsample_feat0 = self.upsample0(fpn_out0)
        f_concat_layer0 = torch.cat([upsample_feat0, x1], 1)
        f_out0 = self.Rep_p4(f_concat_layer0)

        fpn_out1 = self.reduce_layer1(f_out0)
        upsample_feat1 = self.upsample1(fpn_out1)
        f_concat_layer1 = torch.cat([upsample_feat1, x2], 1)
        pan_out2 = self.Rep_p3(f_concat_layer1)

        # Bottom-up (PAN) path: downsample, concat, fuse.
        down_feat1 = self.downsample2(pan_out2)
        p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1)
        pan_out1 = self.Rep_n3(p_concat_layer1)

        down_feat0 = self.downsample1(pan_out1)
        p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1)
        pan_out0 = self.Rep_n4(p_concat_layer2)

        # Fused features for the P3, P4, and P5 detection heads.
        outputs = [pan_out2, pan_out1, pan_out0]

        return outputs
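
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file). The
    # channel and repeat values below are assumptions loosely based on the
    # YOLOv6-s config; substitute the values from your own config.
    channels_list = [64, 128, 256, 512, 1024,       # backbone outputs (0-4)
                     256, 128, 128, 256, 256, 512]  # neck outputs (5-10)
    num_repeats = [1, 6, 12, 18, 6,                 # backbone (0-4)
                   12, 12, 12, 12]                  # neck (5-8)

    neck = RepPANNeck(channels_list=channels_list, num_repeats=num_repeats)

    # Dummy P3/P4/P5 feature maps for a 640x640 input (strides 8, 16, 32).
    x2 = torch.randn(1, channels_list[2], 80, 80)
    x1 = torch.randn(1, channels_list[3], 40, 40)
    x0 = torch.randn(1, channels_list[4], 20, 20)

    p3, p4, p5 = neck((x2, x1, x0))
    # Expected: (1, 128, 80, 80), (1, 256, 40, 40), (1, 512, 20, 20)
    print(p3.shape, p4.shape, p5.shape)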