PKaushik commited on
Commit
b9ee486
1 Parent(s): d3c7726
Files changed (1) hide show
  1. yolov6/models/yolo.py +83 -0
yolov6/models/yolo.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ import math
4
+ import torch.nn as nn
5
+ from yolov6.layers.common import *
6
+ from yolov6.utils.torch_utils import initialize_weights
7
+ from yolov6.models.efficientrep import EfficientRep
8
+ from yolov6.models.reppan import RepPANNeck
9
+ from yolov6.models.effidehead import Detect, build_effidehead_layer
10
+
11
+
12
+ class Model(nn.Module):
13
+ '''YOLOv6 model with backbone, neck and head.
14
+ The default parts are EfficientRep Backbone, Rep-PAN and
15
+ Efficient Decoupled Head.
16
+ '''
17
+ def __init__(self, config, channels=3, num_classes=None, anchors=None): # model, input channels, number of classes
18
+ super().__init__()
19
+ # Build network
20
+ num_layers = config.model.head.num_layers
21
+ self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers)
22
+
23
+ # Init Detect head
24
+ begin_indices = config.model.head.begin_indices
25
+ out_indices_head = config.model.head.out_indices
26
+ self.stride = self.detect.stride
27
+ self.detect.i = begin_indices
28
+ self.detect.f = out_indices_head
29
+ self.detect.initialize_biases()
30
+
31
+ # Init weights
32
+ initialize_weights(self)
33
+
34
+ def forward(self, x):
35
+ x = self.backbone(x)
36
+ x = self.neck(x)
37
+ x = self.detect(x)
38
+ return x
39
+
40
+ def _apply(self, fn):
41
+ self = super()._apply(fn)
42
+ self.detect.stride = fn(self.detect.stride)
43
+ self.detect.grid = list(map(fn, self.detect.grid))
44
+ return self
45
+
46
+
47
+ def make_divisible(x, divisor):
48
+ # Upward revision the value x to make it evenly divisible by the divisor.
49
+ return math.ceil(x / divisor) * divisor
50
+
51
+
52
+ def build_network(config, channels, num_classes, anchors, num_layers):
53
+ depth_mul = config.model.depth_multiple
54
+ width_mul = config.model.width_multiple
55
+ num_repeat_backbone = config.model.backbone.num_repeats
56
+ channels_list_backbone = config.model.backbone.out_channels
57
+ num_repeat_neck = config.model.neck.num_repeats
58
+ channels_list_neck = config.model.neck.out_channels
59
+ num_anchors = config.model.head.anchors
60
+ num_repeat = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in (num_repeat_backbone + num_repeat_neck)]
61
+ channels_list = [make_divisible(i * width_mul, 8) for i in (channels_list_backbone + channels_list_neck)]
62
+
63
+ backbone = EfficientRep(
64
+ in_channels=channels,
65
+ channels_list=channels_list,
66
+ num_repeats=num_repeat
67
+ )
68
+
69
+ neck = RepPANNeck(
70
+ channels_list=channels_list,
71
+ num_repeats=num_repeat
72
+ )
73
+
74
+ head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes)
75
+
76
+ head = Detect(num_classes, anchors, num_layers, head_layers=head_layers)
77
+
78
+ return backbone, neck, head
79
+
80
+
81
+ def build_model(cfg, num_classes, device):
82
+ model = Model(cfg, channels=3, num_classes=num_classes, anchors=cfg.model.head.anchors).to(device)
83
+ return model