OpenOCR-Demo / tools /export_rec.py
topdu's picture
openocr demo
29f689c
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import torch
from openrec.modeling import build_model
from openrec.postprocess import build_post_process
from tools.engine import Config
from tools.infer_rec import build_rec_process
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger
def to_onnx(model, dummy_input, dynamic_axes, sava_path='model.onnx'):
input_axis_name = ['batch_size', 'channel', 'in_width', 'int_height']
output_axis_name = ['batch_size', 'channel', 'out_width', 'out_height']
torch.onnx.export(
model.to('cpu'),
dummy_input,
sava_path,
input_names=['input'],
output_names=['output'], # the model's output names
dynamic_axes={
'input': {axis: input_axis_name[axis]
for axis in dynamic_axes},
'output': {axis: output_axis_name[axis]
for axis in dynamic_axes},
},
)
def export_single_model(model: torch.nn.Module, _cfg, export_dir,
export_config, logger, type):
for layer in model.modules():
if hasattr(layer, 'rep') and not getattr(layer, 'is_repped'):
layer.rep()
os.makedirs(export_dir, exist_ok=True)
export_cfg = {'PostProcess': _cfg['PostProcess']}
export_cfg['Transforms'] = build_rec_process(_cfg)
cfg.save(os.path.join(export_dir, 'config.yaml'), export_cfg)
dummy_input = torch.randn(*export_config['export_shape'], device='cpu')
if type == 'script':
save_path = os.path.join(export_dir, 'model.pt')
trace_model = torch.jit.trace(model, dummy_input, strict=False)
torch.jit.save(trace_model, save_path)
elif type == 'onnx':
save_path = os.path.join(export_dir, 'model.onnx')
to_onnx(model, dummy_input, export_config.get('dynamic_axes', []),
save_path)
else:
raise NotImplementedError
logger.info(f'finish export model to {save_path}')
def main(cfg, type):
_cfg = cfg.cfg
logger = get_logger()
global_config = _cfg['Global']
export_config = _cfg['Export']
# build post process
post_process_class = build_post_process(_cfg['PostProcess'])
char_num = len(getattr(post_process_class, 'character'))
cfg['Architecture']['Decoder']['out_channels'] = char_num
model = build_model(_cfg['Architecture'])
load_ckpt(model, _cfg)
model.eval()
export_dir = export_config.get('export_dir', '')
if not export_dir:
export_dir = os.path.join(global_config.get('output_dir', 'output'),
'export')
if _cfg['Architecture']['algorithm'] in ['Distillation'
]: # distillation model
_cfg['PostProcess'][
'name'] = post_process_class.__class__.__base__.__name__
for model_name in model.model_list:
sub_model_save_path = os.path.join(export_dir, model_name)
export_single_model(
model.model_list[model_name],
_cfg,
sub_model_save_path,
export_config,
logger,
type,
)
else:
export_single_model(model, _cfg, export_dir, export_config, logger,
type)
def parse_args():
parser = ArgsParser()
parser.add_argument('--type',
type=str,
default='onnx',
help='type of export')
args = parser.parse_args()
return args
if __name__ == '__main__':
FLAGS = parse_args()
cfg = Config(FLAGS.config)
FLAGS = vars(FLAGS)
opt = FLAGS.pop('opt')
cfg.merge_dict(FLAGS)
cfg.merge_dict(opt)
main(cfg, FLAGS['type'])