File size: 2,369 Bytes
bbfa6f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import logging
import os
import onnx
import tensorrt as trt
from typing import List
from collections import OrderedDict
from onnx import shape_inference
def _parse_onnx(parser, onnx_path):
    """Feed the ONNX file at *onnx_path* to a TensorRT ONNX parser.

    Raises:
        RuntimeError: if the parser reports failure. Every parser error is
            logged first so the root cause is visible.
    """
    logger = logging.getLogger(__name__)
    with open(onnx_path, "rb") as f:
        success = parser.parse(f.read())
    if not success:
        for idx in range(parser.num_errors):
            logger.error("ONNX parser error: %s", parser.get_error(idx))
        raise RuntimeError("Failed to parse the ONNX file.")


def _force_fp32_layers(network, config, keywords):
    """Pin every layer whose name contains one of *keywords* to FP32.

    Enables the builder's precision-constraint flag so the per-layer
    settings are honoured. Newer TensorRT releases use
    OBEY_PRECISION_CONSTRAINTS; older ones only have STRICT_TYPES.

    Returns:
        (total_layer_count, pinned_layer_count)
    """
    if hasattr(trt.BuilderFlag, "OBEY_PRECISION_CONSTRAINTS"):
        config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
    else:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)

    pinned = 0
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if any(keyword in layer.name for keyword in keywords):
            pinned += 1
            layer.precision = trt.float32
            layer.set_output_type(0, trt.float32)
    return network.num_layers, pinned


def _build_serialized_engine(builder, network, config):
    """Build the engine and return its serialized bytes.

    Prefers ``build_serialized_network`` (newer TensorRT). It falls back to
    the deprecated ``build_engine(...).serialize()`` only when the newer API
    is missing.

    Raises:
        RuntimeError: if TensorRT fails to build the engine. Both APIs signal
            failure by returning None.
    """
    if hasattr(builder, "build_serialized_network"):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError("Failed to build the TensorRT engine.")
    return serialized


def vit_tagging_t2t(input_path="simple_model.onnx", output_path="vit.trt", *,
                    input_name="input", input_shape=(3, 224, 224),
                    max_batch=70, fp32_layer_keywords=()):
    """Convert an ONNX model into a serialized FP16 TensorRT engine file.

    The network is built with an explicit batch dimension. A single
    optimization profile lets the batch range from 1 to *max_batch* and is
    tuned for *max_batch*. The first network input and output are declared
    float32, while FP16 kernels are allowed internally.

    Args:
        input_path: Path of the ONNX model to convert.
        output_path: Where the serialized engine is written.
        input_name: Name of the network input the optimization profile
            applies to.
        input_shape: Per-sample input shape, i.e. the shape without the
            batch dimension.
        max_batch: Largest batch size the engine accepts. This is also the
            profile's optimum.
        fp32_layer_keywords: Optional substrings; layers whose name contains
            any of them are pinned to FP32 (e.g. ``("ReduceMean", "Pow")``).
            Empty by default, meaning no layer is pinned.

    Raises:
        ValueError: if *max_batch* is smaller than 1.
        RuntimeError: if parsing the ONNX file or building the engine fails.
    """
    if max_batch < 1:
        raise ValueError(f"max_batch must be >= 1, got {max_batch}")
    if isinstance(fp32_layer_keywords, str):
        # A bare string would otherwise be matched character by character.
        fp32_layer_keywords = (fp32_layer_keywords,)

    logger = logging.getLogger(__name__)
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    trt_logger = trt.Logger()
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(network_flags) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, trt_logger) as parser:
        config.set_flag(trt.BuilderFlag.FP16)
        _parse_onnx(parser, input_path)

        sample_shape = tuple(input_shape)
        profile = builder.create_optimization_profile()
        profile.set_shape(input_name,
                          min=(1, *sample_shape),
                          opt=(max_batch, *sample_shape),
                          max=(max_batch, *sample_shape))
        config.add_optimization_profile(profile)

        if fp32_layer_keywords:
            total, pinned = _force_fp32_layers(network, config,
                                               fp32_layer_keywords)
            logger.info("Pinned %d of %d layers to FP32", pinned, total)

        # Keep the engine's I/O bindings in float32 even with FP16 enabled.
        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        serialized = _build_serialized_engine(builder, network, config)
        # Written while the builder objects are still alive, as the original did.
        with open(output_path, "wb") as f:
            f.write(serialized)
if __name__ == "__main__":
    import argparse

    # Positional paths are optional, so running the script with no
    # arguments behaves exactly like the original `vit_tagging_t2t()` call.
    _cli = argparse.ArgumentParser(
        description="Convert an ONNX model into a TensorRT engine.")
    _cli.add_argument("input_path", nargs="?", default="simple_model.onnx",
                      help="ONNX model to convert")
    _cli.add_argument("output_path", nargs="?", default="vit.trt",
                      help="where to write the serialized engine")
    _args = _cli.parse_args()
    logging.basicConfig(level=logging.INFO)
    vit_tagging_t2t(_args.input_path, _args.output_path)