File size: 2,369 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import os
import onnx
import tensorrt as trt
from typing import List
from collections import OrderedDict
from onnx import shape_inference


def vit_tagging_t2t(input_path="simple_model.onnx",output_path="vit.trt"):
    model = onnx.load(input_path)
    inferred_model = shape_inference.infer_shapes(model)
    #print(inferred_model.graph.value_info)
    simplified_model = input_path
    bitmask = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    
    trt_logger = trt.Logger()
    all_count,mix_count=0,0
    with trt.Builder(trt_logger) as builder, builder.create_network(bitmask) as network, builder.create_builder_config() as config, trt.OnnxParser(network, trt_logger) as parser:
        #config.max_workspace_size = self.max_workspace_size
        config.set_flag(trt.BuilderFlag.FP16)
        with open(simplified_model, 'rb') as f:
            success = parser.parse(f.read())
            if not success:
                for idx in range(parser.num_errors):
                    print(parser.get_error(idx))
                raise RuntimeError("Failed to parse the ONNX file.")
            profile = builder.create_optimization_profile()
            min_shape = [3,224,224]
            max_shape = [3,224,224]
            opt_shape = max_shape  #opt shape=max shape by default
            profile.set_shape("input",
                            min=(1, *min_shape),
                            opt=(70, *opt_shape),
                            max=(70, *max_shape))

            config.add_optimization_profile(profile)
        """
        for i in range(network.num_layers):
            all_count+=1
            layer = network.get_layer(i)
            if "ReduceMean" in layer.name or "Pow" in layer.name:
                mix_count+=1
                config.set_flag(trt.BuilderFlag.STRICT_TYPES)
                layer.precision = trt.float32
                layer.set_output_type(0, trt.float32)
        """
        #networtgetInput(0)->setType(DataType::kHALF)
        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        print(all_count,mix_count) 
        engine = builder.build_engine(network, config)
        #print(engine)
        with open(output_path, 'wb') as f:
            f.write(engine.serialize())
            f.close()
    
if __name__ == "__main__":
    # Script entry point: convert the ONNX model into a TensorRT engine,
    # spelling out the same paths the function uses by default.
    vit_tagging_t2t(input_path="simple_model.onnx", output_path="vit.trt")