File size: 3,942 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys

# Absolute path of the directory that contains this script.
# NOTE(review): Python treats a module-level `__dir__` as the hook used by
# dir(module) (PEP 562); binding a string under this name is risky, so a
# different name would be safer -- confirm nothing else relies on it.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Put this directory (lowest priority) and its parent (highest priority) on
# the import path; presumably this is what lets the `openrec` and `tools`
# imports below resolve -- confirm against the repository layout.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import torch

from openrec.modeling import build_model
from openrec.postprocess import build_post_process
from tools.engine import Config
from tools.infer_rec import build_rec_process
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger


def to_onnx(model, dummy_input, dynamic_axes, sava_path='model.onnx'):
    """Write ``model`` to an ONNX file with ``torch.onnx.export``.

    The model is moved to the CPU and exported with one tensor named
    ``'input'`` and one named ``'output'``. Each index in ``dynamic_axes``
    is declared dynamic on both tensors, using a fixed per-axis label.

    Args:
        model: Module to export.
        dummy_input: Example input handed to the exporter.
        dynamic_axes: Iterable of axis indices to mark as dynamic. Only four
            labels exist per tensor, so an index outside that range raises
            ``IndexError``.
        sava_path: Destination file. The parameter name is kept as-is so
            keyword callers keep working.
    """
    # NOTE(review): these labels end up as symbolic dimension names in the
    # exported graph. 'int_height' looks like a typo for 'in_height', and the
    # width/height order may not match the real tensor layout -- confirm
    # before changing, since consumers could depend on the names.
    in_labels = ('batch_size', 'channel', 'in_width', 'int_height')
    out_labels = ('batch_size', 'channel', 'out_width', 'out_height')

    cpu_model = model.to('cpu')

    dynamic_spec = {}
    for tensor_name, labels in (('input', in_labels), ('output',
                                                       out_labels)):
        dynamic_spec[tensor_name] = {
            axis: labels[axis]
            for axis in dynamic_axes
        }

    torch.onnx.export(
        cpu_model,
        dummy_input,
        sava_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=dynamic_spec,
    )


def export_single_model(model: torch.nn.Module, _cfg, export_dir,
                        export_config, logger, type, config=None):
    """Export one model as TorchScript or ONNX, plus an inference config.

    Args:
        model: Module to export. Every sub-module that exposes ``rep`` and
            whose ``is_repped`` attribute is falsy or missing gets ``rep()``
            called on it before export.
        _cfg: Configuration mapping. ``_cfg['PostProcess']`` is copied into
            the saved config and the whole mapping is passed to
            ``build_rec_process`` to fill the ``'Transforms'`` entry.
        export_dir: Output directory; created when it does not exist.
        export_config: Mapping providing ``'export_shape'`` (shape of the
            random dummy input) and optionally ``'dynamic_axes'`` (ONNX only).
        logger: Logger that receives the completion message.
        type: ``'script'`` writes ``model.pt`` via ``torch.jit.trace``;
            ``'onnx'`` writes ``model.onnx`` via ``to_onnx``.
        config: Object whose ``save(path, cfg_dict)`` writes ``config.yaml``.
            When omitted, a module-level ``cfg`` is used if one exists (the
            script entry point of this file defines one), which keeps old
            call sites working.

    Raises:
        NotImplementedError: If ``type`` is neither ``'script'`` nor
            ``'onnx'``. Raised before anything is modified or written.
        RuntimeError: If no ``config`` is given and no module-level ``cfg``
            is available.
    """
    # Reject unknown formats first so a bad argument cannot leave behind a
    # re-parameterised model, a new directory or a stray config.yaml.
    if type not in ('script', 'onnx'):
        raise NotImplementedError(
            f"unsupported export type {type!r}; expected 'script' or 'onnx'")

    if config is None:
        # Legacy path: earlier versions read the global set under __main__.
        config = globals().get('cfg')
        if config is None:
            raise RuntimeError(
                'export_single_model needs a `config` object to save '
                'config.yaml when no module-level `cfg` is defined')

    for layer in model.modules():
        # A layer may offer `rep` without tracking `is_repped`; treat a
        # missing flag as "not yet re-parameterised".
        if hasattr(layer, 'rep') and not getattr(layer, 'is_repped', False):
            layer.rep()
    os.makedirs(export_dir, exist_ok=True)

    export_cfg = {'PostProcess': _cfg['PostProcess']}
    export_cfg['Transforms'] = build_rec_process(_cfg)

    config.save(os.path.join(export_dir, 'config.yaml'), export_cfg)

    dummy_input = torch.randn(*export_config['export_shape'], device='cpu')
    if type == 'script':
        save_path = os.path.join(export_dir, 'model.pt')
        # strict=False is forwarded to torch.jit.trace unchanged.
        trace_model = torch.jit.trace(model, dummy_input, strict=False)
        torch.jit.save(trace_model, save_path)
    else:  # 'onnx' -- the only other value accepted above
        save_path = os.path.join(export_dir, 'model.onnx')
        to_onnx(model, dummy_input, export_config.get('dynamic_axes', []),
                save_path)
    logger.info(f'finish export model to {save_path}')


def main(cfg, type):
    """Build the model described by ``cfg``, load its checkpoint, export it.

    Args:
        cfg: Configuration object. Its ``cfg`` attribute is the mapping read
            here (``'Global'``, ``'Export'``, ``'PostProcess'``,
            ``'Architecture'``); the object itself is handed to
            ``export_single_model`` to save ``config.yaml``.
        type: Export format forwarded to ``export_single_model``.
    """
    _cfg = cfg.cfg
    logger = get_logger()
    global_config = _cfg['Global']
    export_config = _cfg['Export']

    # The decoder's output width is the size of the post-processor's
    # character set. Write it into `_cfg`, because that is the mapping the
    # model is built from on the next line.
    post_process_class = build_post_process(_cfg['PostProcess'])
    char_num = len(post_process_class.character)
    _cfg['Architecture']['Decoder']['out_channels'] = char_num
    model = build_model(_cfg['Architecture'])

    load_ckpt(model, _cfg)
    model.eval()

    # An empty or missing 'export_dir' falls back to <output_dir>/export.
    export_dir = export_config.get('export_dir', '')
    if not export_dir:
        export_dir = os.path.join(global_config.get('output_dir', 'output'),
                                  'export')

    if _cfg['Architecture']['algorithm'] == 'Distillation':
        # Each sub-model is exported on its own, so the saved config names
        # the base class of the post-process object instead of its own class.
        _cfg['PostProcess'][
            'name'] = post_process_class.__class__.__base__.__name__
        for model_name in model.model_list:
            sub_model_save_path = os.path.join(export_dir, model_name)
            export_single_model(
                model.model_list[model_name],
                _cfg,
                sub_model_save_path,
                export_config,
                logger,
                type,
                config=cfg,
            )
    else:
        export_single_model(model,
                            _cfg,
                            export_dir,
                            export_config,
                            logger,
                            type,
                            config=cfg)


def parse_args():
    """Parse the command line with ``ArgsParser`` plus a ``--type`` option.

    Returns:
        The object returned by ``ArgsParser.parse_args()``. ``--type`` is a
        string and defaults to ``'onnx'``.
    """
    arg_parser = ArgsParser()
    arg_parser.add_argument(
        '--type',
        default='onnx',
        type=str,
        help='type of export',
    )
    return arg_parser.parse_args()


if __name__ == '__main__':
    # Script entry point: load the config named on the command line, merge
    # the CLI values into it, then run the export.
    cli_args = parse_args()
    # Keep the name `cfg`: export_single_model resolves it as a module-level
    # global.
    cfg = Config(cli_args.config)
    flag_dict = vars(cli_args)
    # 'opt' is removed from the flags and merged after the remaining ones.
    overrides = flag_dict.pop('opt')
    cfg.merge_dict(flag_dict)
    cfg.merge_dict(overrides)
    main(cfg, flag_dict['type'])