|
|
|
import argparse |
|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
|
|
# Ordered renaming rules: (substring that selects the key, text to replace,
# replacement). Rules are tried in order and the first match wins, exactly
# like an if/elif chain, so their order is significant.
_KEY_RULES = (
    # Replace only the layer index itself. A bare '.0' -> '.2' would also
    # rewrite any nested '.0' sub-module index further along the key.
    ('neck.reduce_layers.0', 'reduce_layers.0', 'reduce_layers.2'),
    ('neck.reduce_layers.1', 'reduce_layers.1', 'top_down_layers.0.1'),
    ('neck.top_down_blocks.0', 'down_blocks', 'down_layers.0'),
    ('neck.top_down_blocks.1', 'down_blocks', 'down_layers'),
    ('downsamples', 'downsamples', 'downsample_layers'),
    ('bottom_up_blocks', 'bottom_up_blocks', 'bottom_up_layers'),
    ('out_convs', 'out_convs', 'out_layers'),
    ('bbox_head', 'bbox_head', 'bbox_head.head_module'),
)


def _rename_key(key):
    """Return the MMYOLO name for one source state-dict key.

    Args:
        key (str): Parameter/buffer name from the RTMDet state dict.

    Returns:
        str | None: The renamed key, the key unchanged when no rule applies,
        or ``None`` when the entry should be dropped (``data_preprocessor``
        entries).
    """
    for needle, old, new in _KEY_RULES:
        if needle in key:
            return key.replace(old, new)
    if 'data_preprocessor' in key:
        return None
    return key


def convert(src, dst):
    """Convert keys in pretrained RTMDet models to MMYOLO style.

    Args:
        src (str): Path of the source checkpoint. It must contain a
            ``'state_dict'`` entry; a top-level ``'meta'`` entry, if present,
            is carried over.
        dst (str): Path the converted checkpoint is written to, as a dict
            with ``'state_dict'`` and ``'meta'`` keys.
    """
    # NOTE(review): torch.load unpickles the file, so only convert
    # checkpoints from trusted sources. Mapping to CPU lets checkpoints saved
    # on GPU be converted on machines without CUDA.
    ckpt = torch.load(src, map_location='cpu')
    blobs = ckpt['state_dict']
    state_dict = OrderedDict()

    for key, weight in blobs.items():
        new_key = _rename_key(key)
        if new_key is None:
            continue
        state_dict[new_key] = weight
        print(f'Convert {key} to {new_key}')

    # 'meta' sits next to 'state_dict' at the top level of the checkpoint.
    # Looking it up in the state dict itself always produced None.
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    checkpoint['meta'] = ckpt.get('meta')
    torch.save(checkpoint, dst)
|
|
|
|
|
def main():
    """Command-line entry point: read ``src`` and ``dst`` paths, then convert.

    Both arguments are positional; the converted checkpoint is written to
    ``dst``.
    """
    cli = argparse.ArgumentParser(description='Convert model keys')
    positional = (
        ('src', 'src rtm model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()
    convert(opts.src, opts.dst)
|
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|
|