File size: 4,403 Bytes
186701e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
from collections import OrderedDict

import torch


# Ordered renaming rules applied to YOLOv6 state-dict keys. Each entry is
# (fragment that selects the rule, substring to replace, replacement,
# whether to also rename inner ``.cv*`` -> ``.conv*`` and ``.m.`` ->
# ``.block.`` submodules). The first rule whose fragment occurs in the
# original key wins, so the order of this tuple matters.
_KEY_RULES = (
    ('ERBlock_2', 'ERBlock_2', 'stage1.0', True),
    ('ERBlock_3', 'ERBlock_3', 'stage2.0', True),
    ('ERBlock_4', 'ERBlock_4', 'stage3.0', True),
    ('ERBlock_5', 'ERBlock_5', 'stage4.0', True),
    ('reduce_layer0', 'reduce_layer0', 'reduce_layers.2', False),
    ('Rep_p4', 'Rep_p4', 'top_down_layers.0.0', True),
    ('reduce_layer1', 'reduce_layer1', 'top_down_layers.0.1', True),
    ('Rep_p3', 'Rep_p3', 'top_down_layers.1', True),
    ('upsample0', 'upsample0.upsample_transpose', 'upsample_layers.0', False),
    ('upsample1', 'upsample1.upsample_transpose', 'upsample_layers.1', False),
    ('Rep_n3', 'Rep_n3', 'bottom_up_layers.0', True),
    ('Rep_n4', 'Rep_n4', 'bottom_up_layers.1', True),
    ('downsample2', 'downsample2', 'downsample_layers.0', False),
    ('downsample1', 'downsample1', 'downsample_layers.1', False),
)


def _is_dropped(key):
    """Return True for state-dict entries that are not carried over."""
    # The detect head's 'proj' entries and the anchor buffers are skipped.
    if 'detect' in key and 'proj' in key:
        return True
    return 'anchors' in key or 'anchor_grid' in key


def _rename_key(key):
    """Map one YOLOv6 state-dict key to its MMYOLO name."""
    for fragment, old, new, rename_inner in _KEY_RULES:
        if fragment not in key:
            continue
        name = key.replace(old, new)
        if rename_inner:
            name = name.replace('.cv', '.conv').replace('.m.', '.block.')
        if fragment == 'ERBlock_5' and 'stage4.0.2' in name:
            # The third module of ERBlock_5 becomes a separate stage4.1
            # (presumably the SPPF layer -- confirm against both repos).
            name = name.replace('stage4.0.2', 'stage4.1')
            name = name.replace('cv', 'conv')
        return name
    # Keys not covered by a rule keep their name, except the head.
    if 'detect' in key:
        return key.replace('detect', 'bbox_head.head_module')
    return key


def convert(src, dst):
    """Convert an official YOLOv6 checkpoint to MMYOLO parameter names.

    The ``'ema'`` model is used when the checkpoint holds a truthy one,
    otherwise ``'model'``. Its parameters are cast with ``.float()``,
    renamed via ``_KEY_RULES`` (some entries are dropped, see
    ``_is_dropped``) and saved as ``{'state_dict': OrderedDict}``.

    Args:
        src: Path of the YOLOv6 checkpoint, a pickled dict holding the
            model object under ``'ema'`` and/or ``'model'``.
        dst: Path the converted checkpoint is written to.

    Raises:
        RuntimeError: If unpickling fails with ``ModuleNotFoundError``,
            i.e. the YOLOv6 modules referenced by the pickle are not
            importable.
    """
    import inspect
    import sys

    # Make a relative ``yolov6`` directory importable so pickle can
    # resolve the model classes stored in the checkpoint.
    sys.path.append('yolov6')
    load_kwargs = {'map_location': torch.device('cpu')}
    # The checkpoint pickles whole model objects. PyTorch >= 2.6 defaults
    # torch.load to weights_only=True, which rejects them, so opt out
    # explicitly where the argument exists (torch >= 1.13).
    # NOTE: full unpickling can execute arbitrary code -- only convert
    # checkpoints from a trusted source.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    try:
        ckpt = torch.load(src, **load_kwargs)
    except ModuleNotFoundError as err:
        raise RuntimeError(
            'This script must be placed under the meituan/YOLOv6 repo,'
            ' because loading the official pretrained model need'
            ' some python files to build model.') from err
    # The saved model is the model before reparameterization
    model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
    new_state_dict = OrderedDict(
        (_rename_key(key), value)
        for key, value in model.state_dict().items()
        if not _is_dropped(key))
    torch.save({'state_dict': new_state_dict}, dst)


# Note: run this script from inside the yolov6 repository.
def main():
    """Command-line entry point: read ``--src``/``--dst`` and convert."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    cli.add_argument('--src', default='yolov6s.pt',
                     help='src yolov6 model path')
    cli.add_argument('--dst', default='mmyolov6.pt', help='save path')
    opts = cli.parse_args()
    convert(src=opts.src, dst=opts.dst)


if __name__ == '__main__':
    main()