|
import argparse |
|
from pathlib import Path |
|
from typing import List, Optional, Tuple, Union |
|
|
|
try: |
|
import tensorrt as trt |
|
except Exception: |
|
trt = None |
|
import warnings |
|
|
|
import numpy as np |
|
import torch |
|
|
|
# Silence DeprecationWarning for the whole process so deprecation notices
# emitted by the imported libraries do not clutter the build log.
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
|
|
|
class EngineBuilder:
    """Build a serialized TensorRT engine from an ONNX checkpoint.

    The engine is written next to the checkpoint, using the same file name
    with the suffix replaced by ``.engine``.

    Args:
        checkpoint: Path to an existing ``.onnx`` file.
        opt_shape: Optimal input shape. It is only used to derive the
            min/opt/max optimization-profile shapes when the network has a
            dynamic input and no explicit ``scale`` is passed to ``build``.
        device: Device whose total memory is used as the builder workspace
            limit. A ``str`` is converted with ``torch.device(device)`` and
            an ``int`` is interpreted as a CUDA index.

    Raises:
        FileNotFoundError: If ``checkpoint`` does not exist.
        ValueError: If ``checkpoint`` does not have the ``.onnx`` suffix.
    """

    # Per-dimension multipliers applied to ``opt_shape`` to obtain the
    # (min, opt, max) profile shapes when no explicit scale is supplied.
    _DEFAULT_SCALE = ((1, 1, 0.5, 0.5), (1, 1, 1, 1), (4, 1, 1.5, 1.5))

    def __init__(
            self,
            checkpoint: Union[str, Path],
            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        checkpoint = Path(checkpoint)
        # Explicit exceptions instead of ``assert`` so that the validation
        # still runs when Python is started with ``-O``.
        if not checkpoint.exists():
            raise FileNotFoundError(
                f'ONNX checkpoint does not exist: {checkpoint}')
        if checkpoint.suffix != '.onnx':
            raise ValueError(
                f'checkpoint must be an .onnx file, got: {checkpoint}')

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')

        self.checkpoint = checkpoint
        # Stored as float so it can be multiplied by fractional multipliers.
        self.opt_shape = np.array(opt_shape, dtype=np.float32)
        self.device = device

    def _resolve_scale(self, scale=None) -> np.ndarray:
        """Return the (min, opt, max) profile shapes as an int32 array.

        Args:
            scale: ``None`` to derive the shapes from ``self.opt_shape``, or
                a list/tuple/ndarray holding exactly three shapes in
                (min, opt, max) order.

        Raises:
            ValueError: If an explicit ``scale`` is not a 2-D collection of
                exactly three shapes.
            NotImplementedError: If ``scale`` is of an unsupported type.
        """
        if scale is None:
            factors = np.array(self._DEFAULT_SCALE, dtype=np.float32)
            return (self.opt_shape * factors).astype(np.int32)
        if isinstance(scale, (list, tuple, np.ndarray)):
            shapes = np.asarray(scale, dtype=np.int32)
            if shapes.ndim != 2 or shapes.shape[0] != 3:
                raise ValueError('Input a wrong scale list')
            return shapes
        raise NotImplementedError

    def __build_engine(self,
                       scale: Optional[List[List]] = None,
                       fp16: bool = True,
                       with_profiling: bool = True) -> None:
        """Parse the ONNX file, configure the builder and write the engine.

        Args:
            scale: Optional (min, opt, max) input shapes. Only consulted when
                the first network input has a dynamic (``-1``) dimension.
            fp16: Enable the FP16 builder flag when the builder reports
                ``platform_has_fast_fp16``.
            with_profiling: Set the profiling verbosity to ``DETAILED``.

        Raises:
            ImportError: If the ``tensorrt`` module could not be imported.
            RuntimeError: If the ONNX file cannot be parsed or the engine
                build fails.
        """
        if trt is None:
            # The module-level import is guarded; fail with a clear message
            # rather than an AttributeError on ``None``.
            raise ImportError(
                'tensorrt is required to build an engine but could not be '
                'imported')

        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()

        # Let the builder use up to the whole memory of the target device.
        workspace = torch.cuda.get_device_properties(
            self.device).total_memory
        if hasattr(config, 'set_memory_pool_limit'):
            # Newer TensorRT API; ``max_workspace_size`` is deprecated.
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                         workspace)
        else:
            config.max_workspace_size = workspace

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(self.checkpoint)):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]

        # Only the first input is inspected to decide whether an
        # optimization profile is required.
        dshape = -1 in network.get_input(0).shape
        profile = None
        shapes = ()
        if dshape:
            profile = builder.create_optimization_profile()
            # Plain int tuples for the TensorRT binding (min, opt, max).
            shapes = [
                tuple(int(dim) for dim in shape)
                for shape in self._resolve_scale(scale)
            ]

        for inp in inputs:
            logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
            if dshape:
                # The same (min, opt, max) shapes are applied to every input.
                profile.set_shape(inp.name, *shapes)
        for out in outputs:
            logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape{out.shape} {out.dtype}')

        if fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')
        if dshape:
            config.add_optimization_profile(profile)
        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

        if hasattr(builder, 'build_serialized_network'):
            # Preferred path; ``build_engine`` is deprecated.
            serialized = builder.build_serialized_network(network, config)
            if serialized is None:
                raise RuntimeError(
                    f'failed to build TensorRT engine from: '
                    f'{str(self.checkpoint)}')
            self.weight.write_bytes(serialized)
        else:
            engine = builder.build_engine(network, config)
            if engine is None:
                raise RuntimeError(
                    f'failed to build TensorRT engine from: '
                    f'{str(self.checkpoint)}')
            with engine:
                self.weight.write_bytes(engine.serialize())

        logger.log(
            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
            f'Save in {str(self.weight.absolute())}')

    def build(self,
              scale: Optional[List[List]] = None,
              fp16: bool = True,
              with_profiling: bool = True) -> None:
        """Build the engine and save it beside the ONNX checkpoint.

        Args:
            scale: Optional (min, opt, max) input shapes for a dynamic-shape
                network; ``None`` derives them from ``opt_shape``.
            fp16: Request FP16 mode when the platform supports it.
            with_profiling: Enable detailed profiling verbosity.
        """
        self.__build_engine(scale, fp16, with_profiling)
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the engine build script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the behaviour of
            calling this function without arguments.

    Returns:
        argparse.Namespace with ``checkpoint``, ``img_size``, ``device``,
        ``scales`` and ``fp16``. A single ``--img-size`` value is duplicated
        so that ``img_size`` holds both height and width.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='TensorRT builder device')
    parser.add_argument(
        '--scales',
        type=str,
        default='[[1,3,640,640],[1,3,640,640],[1,3,640,640]]',
        help='Input scales for build dynamic input shape engine')
    parser.add_argument(
        '--fp16', action='store_true', help='Build model with fp16 mode')
    args = parser.parse_args(argv)
    # A single value means a square image: use it for height and width.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    return args
|
|
|
|
|
def main(args):
    """Build a TensorRT engine from the parsed command-line arguments.

    Args:
        args: Namespace providing ``checkpoint``, ``img_size``, ``device``,
            ``scales`` and ``fp16``.

    ``args.scales`` is expected to be the text of a Python literal (a nested
    list of integers). If it cannot be parsed, a notice is printed and the
    build falls back to ``scales=None``.
    """
    # Local import keeps this function self-contained.
    import ast

    img_size = (1, 3, *args.img_size)
    try:
        # ``literal_eval`` only accepts Python literals, so text coming from
        # the command line can never execute arbitrary code (unlike ``eval``).
        scales = ast.literal_eval(args.scales)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        print('Input scales is not a python variable')
        print('Set scales default None')
        scales = None
    builder = EngineBuilder(args.checkpoint, img_size, args.device)
    builder.build(scales, fp16=args.fp16)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and run the build.
    main(parse_args())
|
|