import warnings
from collections import namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import onnxruntime

try:
    import tensorrt as trt
except Exception:
    trt = None
import torch

warnings.filterwarnings(action='ignore', category=DeprecationWarning)


class TRTWrapper(torch.nn.Module):
    """Thin torch.nn.Module wrapper around a serialized TensorRT engine."""

    # TensorRT dtype -> torch dtype lookup, filled in from the `trt` module on
    # construction (see __update_mapping).
    dtype_mapping = {}

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.engine', '.plan')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.stream = torch.cuda.Stream(device=device)
        self.__update_mapping()
        self.__init_engine()
        self.__init_bindings()

    def __update_mapping(self):
        self.dtype_mapping.update({
            trt.bool: torch.bool,
            trt.int8: torch.int8,
            trt.int32: torch.int32,
            trt.float16: torch.float16,
            trt.float32: torch.float32
        })

    def __init_engine(self):
        logger = trt.Logger(trt.Logger.ERROR)
        self.log = partial(logger.log, trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(logger, namespace='')
        self.logger = logger
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())

        context = model.create_execution_context()

        names = [model.get_binding_name(i) for i in range(model.num_bindings)]

        num_inputs, num_outputs = 0, 0

        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1

        # A dynamic engine reports -1 for the unresolved dimensions of its
        # first (input) binding.
        self.is_dynamic = -1 in model.get_binding_shape(0)

        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_bindings = num_inputs + num_outputs
        self.bindings: List[int] = [0] * self.num_bindings

    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []

        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            inputs_info.append(Binding(name, dtype, shape))

        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = self.dtype_mapping[self.model.get_binding_dtype(i)]
            shape = tuple(self.model.get_binding_shape(i))
            outputs_info.append(Binding(name, dtype, shape))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        # Static engines have fixed output shapes, so their output buffers can
        # be allocated once here and reused on every forward pass.
        if not self.is_dynamic:
            self.output_tensor = [
                torch.empty(o.shape, dtype=o.dtype, device=self.device)
                for o in outputs_info
            ]

    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs

        contiguous_inputs: List[torch.Tensor] = [
            i.contiguous() for i in inputs
        ]

        for i in range(self.num_inputs):
            self.bindings[i] = contiguous_inputs[i].data_ptr()
            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))

        outputs: List[torch.Tensor] = []

        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                # Output shapes are only known once the input binding shapes
                # have been set above.
                shape = tuple(self.context.get_binding_shape(j))
                output = torch.empty(
                    size=shape,
                    dtype=self.outputs_info[i].dtype,
                    device=self.device)
            else:
                output = self.output_tensor[i]
            outputs.append(output)
            self.bindings[j] = output.data_ptr()

        self.context.execute_async_v2(self.bindings, self.stream.cuda_stream)
        self.stream.synchronize()

        return tuple(outputs)
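
# Usage sketch for TRTWrapper (illustrative only: the engine path and the
# input shape below are assumptions, not artifacts produced by this module):
#
#   model = TRTWrapper('model.engine', device='cuda:0')
#   outputs = model(torch.randn(1, 3, 640, 640, device='cuda:0'))
#   print([o.shape for o in outputs])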


class ORTWrapper(torch.nn.Module):
    """Thin torch.nn.Module wrapper around an ONNX Runtime InferenceSession."""

    def __init__(self, weight: Union[str, Path],
                 device: Optional[torch.device]):
        super().__init__()
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix == '.onnx'

        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__init_session()
        self.__init_bindings()

    def __init_session(self):
        # Prefer the CUDA execution provider when a CUDA device was requested.
        providers = ['CPUExecutionProvider']
        if 'cuda' in self.device.type:
            providers.insert(0, 'CUDAExecutionProvider')

        session = onnxruntime.InferenceSession(
            str(self.weight), providers=providers)
        self.session = session

    def __init_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape'))
        inputs_info = []
        outputs_info = []
        # Any non-integer dimension (e.g. a symbolic name such as 'batch')
        # marks the model as dynamically shaped.
        self.is_dynamic = False
        for tensor in self.session.get_inputs():
            if any(not isinstance(dim, int) for dim in tensor.shape):
                self.is_dynamic = True
            inputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))

        for tensor in self.session.get_outputs():
            outputs_info.append(
                Binding(tensor.name, tensor.type, tuple(tensor.shape)))
        self.inputs_info = inputs_info
        self.outputs_info = outputs_info
        self.num_inputs = len(inputs_info)

    def forward(self, *inputs):
        assert len(inputs) == self.num_inputs

        contiguous_inputs: List[np.ndarray] = [
            i.contiguous().cpu().numpy() for i in inputs
        ]

        if not self.is_dynamic:
            # Static models expect exactly the exported input shapes.
            for i in range(self.num_inputs):
                assert contiguous_inputs[i].shape == self.inputs_info[i].shape

        outputs = self.session.run([o.name for o in self.outputs_info], {
            j.name: contiguous_inputs[i]
            for i, j in enumerate(self.inputs_info)
        })

        return tuple(torch.from_numpy(o).to(self.device) for o in outputs)
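

# Minimal smoke-test sketch. The file names and the input shape here are
# assumptions for illustration; substitute the artifacts exported for your
# own model before running.
if __name__ == '__main__':
    dummy = torch.randn(1, 3, 640, 640)

    ort_model = ORTWrapper('model.onnx', device='cpu')
    print([o.shape for o in ort_model(dummy)])

    if trt is not None and torch.cuda.is_available():
        trt_model = TRTWrapper('model.engine', device='cuda:0')
        print([o.shape for o in trt_model(dummy.cuda())])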