# File size: 5,007 Bytes · 186701e
import argparse
from pathlib import Path
from typing import List, Optional, Tuple, Union
try:
import tensorrt as trt
except Exception:
trt = None
import warnings
import numpy as np
import torch
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class EngineBuilder:
    """Build a TensorRT engine file from an ONNX checkpoint.

    The serialized engine is written next to the checkpoint, using the same
    file name with the suffix replaced by ``.engine``.
    """

    # Multipliers applied element-wise to ``opt_shape`` when no explicit
    # ``scale`` is given. One row per shape passed to
    # ``IOptimizationProfile.set_shape`` (min, opt, max in the TensorRT API).
    _DEFAULT_SCALE = ((1, 1, 0.5, 0.5), (1, 1, 1, 1), (4, 1, 1.5, 1.5))

    def __init__(
            self,
            checkpoint: Union[str, Path],
            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        """Validate the checkpoint path and normalise the device.

        Args:
            checkpoint: Path to an existing ``.onnx`` file.
            opt_shape: Reference input shape; only used to derive the default
                profile shapes when the network has a dynamic input and no
                explicit ``scale`` is passed to :meth:`build`.
            device: A ``torch.device``, a device string, or an integer that is
                interpreted as a CUDA device index. ``None`` is stored as-is.

        Raises:
            AssertionError: If ``checkpoint`` does not exist or does not have
                the ``.onnx`` suffix.
        """
        checkpoint = Path(checkpoint)
        # Raised explicitly instead of via ``assert`` so that the validation
        # is not stripped when Python runs with ``-O``.
        if not checkpoint.exists():
            raise AssertionError(f'checkpoint does not exist: {checkpoint}')
        if checkpoint.suffix != '.onnx':
            raise AssertionError(
                f'checkpoint must be an .onnx file: {checkpoint}')

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')

        self.checkpoint = checkpoint
        # Kept as float32 so it can be multiplied by fractional multipliers.
        self.opt_shape = np.array(opt_shape, dtype=np.float32)
        self.device = device

    def _resolve_scale(self, scale) -> np.ndarray:
        """Return the three optimization-profile shapes as an int32 array.

        ``None`` derives the shapes from ``self.opt_shape`` and
        ``_DEFAULT_SCALE``; a list, tuple or ndarray is taken verbatim.

        Raises:
            AssertionError: If an explicit ``scale`` does not have exactly
                three rows.
            NotImplementedError: If ``scale`` is of an unsupported type.
        """
        if scale is None:
            multipliers = np.array(self._DEFAULT_SCALE, dtype=np.float32)
            return (self.opt_shape * multipliers).astype(np.int32)
        if isinstance(scale, (list, tuple, np.ndarray)):
            shapes = np.array(scale, dtype=np.int32)
            if shapes.shape[0] != 3:
                raise AssertionError('Input a wrong scale list')
            return shapes
        raise NotImplementedError

    def __build_engine(self,
                       scale: Optional[List[List]] = None,
                       fp16: bool = True,
                       with_profiling: bool = True) -> None:
        """Parse the ONNX file, build the engine and write it to disk.

        Args:
            scale: Three shapes for the optimization profile; only consulted
                when the first network input has a ``-1`` dimension.
            fp16: Enable the FP16 builder flag if the platform reports fast
                FP16 support.
            with_profiling: Request detailed profiling verbosity.

        Raises:
            RuntimeError: If TensorRT is unavailable, the ONNX file cannot be
                parsed, or the engine build yields no result.
        """
        if trt is None:
            # The module-level import fell back to ``trt = None``.
            raise RuntimeError(
                'tensorrt is not available; install TensorRT to build engines')

        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()

        # Let the builder use up to the whole device memory as workspace.
        workspace = torch.cuda.get_device_properties(self.device).total_memory
        if hasattr(config, 'set_memory_pool_limit'):
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                         workspace)
        else:
            # Older TensorRT releases only expose the attribute form.
            config.max_workspace_size = workspace

        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(self.checkpoint)):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')

        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]

        # Only the first input decides whether a profile is required.
        dshape = -1 in network.get_input(0).shape
        profile = None
        shapes = None
        if dshape:
            profile = builder.create_optimization_profile()
            # Plain Python ints are the safest form to hand to the bindings.
            shapes = self._resolve_scale(scale).tolist()

        for inp in inputs:
            logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
            if dshape:
                # The same three shapes are applied to every network input.
                profile.set_shape(inp.name, *shapes)
        for out in outputs:
            logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape{out.shape} {out.dtype}')

        if fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')
        if dshape:
            config.add_optimization_profile(profile)
        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

        if hasattr(builder, 'build_serialized_network'):
            serialized = builder.build_serialized_network(network, config)
        else:
            # Fallback for TensorRT releases without the serialized builder.
            engine = builder.build_engine(network, config)
            serialized = None if engine is None else engine.serialize()
        if serialized is None:
            raise RuntimeError(
                f'failed to build TensorRT engine from: {self.checkpoint}')

        self.weight.write_bytes(serialized)
        logger.log(
            trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
            f'Save in {str(self.weight.absolute())}')

    def build(self,
              scale: Optional[List[List]] = None,
              fp16: bool = True,
              with_profiling=True):
        """Build the engine and save it as ``<checkpoint>.engine``.

        Thin public wrapper that forwards all arguments to the private
        build routine.
        """
        self.__build_engine(scale, fp16, with_profiling)
def parse_args():
    """Parse the command line of the engine-export script.

    Returns the ``argparse`` namespace. When ``--img-size`` is given a single
    value, it is duplicated so that ``img_size`` holds two entries.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('checkpoint', help='Checkpoint file')
    cli.add_argument('--img-size',
                     nargs='+',
                     type=int,
                     default=[640, 640],
                     help='Image size of height and width')
    cli.add_argument('--device',
                     type=str,
                     default='cuda:0',
                     help='TensorRT builder device')
    cli.add_argument(
        '--scales',
        type=str,
        default='[[1,3,640,640],[1,3,640,640],[1,3,640,640]]',
        help='Input scales for build dynamic input shape engine')
    cli.add_argument('--fp16',
                     action='store_true',
                     help='Build model with fp16 mode')

    parsed = cli.parse_args()
    # A lone value is used for both entries of the image size.
    if len(parsed.img_size) == 1:
        parsed.img_size = parsed.img_size * 2
    return parsed
def main(args):
    """Build a TensorRT engine from the parsed command-line arguments.

    ``args.scales`` is evaluated as a Python expression; if evaluation
    fails, a notice is printed and ``None`` is passed on instead.
    """
    # The local names ``args`` and ``img_size`` are kept deliberately:
    # the eval() below can see this function's locals.
    img_size = (1, 3) + tuple(args.img_size)

    # NOTE(review): eval() executes arbitrary Python taken from the
    # ``--scales`` string. Acceptable only while the command line is
    # trusted; consider ast.literal_eval if plain literals are enough.
    try:
        scales = eval(args.scales)
    except Exception:
        print('Input scales is not a python variable')
        print('Set scales default None')
        scales = None

    engine_builder = EngineBuilder(args.checkpoint, img_size, args.device)
    engine_builder.build(scales, fp16=args.fp16)
# Script entry point: parse the command line and run the export.
if __name__ == '__main__':
    main(parse_args())
|