monai
medical
File size: 4,138 Bytes
a23faaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os

import numpy as np
import onnx
import onnxruntime
import torch
from monai.networks.nets import SEResNet50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model_and_export(modelname, outname, out_channels, height, width, multigpu=False, in_channels=3):
    """
    Load SEResNet50 weights, export the network to ONNX and verify the export.

    The network is placed on the module-level ``device`` (CUDA when available,
    otherwise CPU), run once on a seeded random input, exported with a dynamic
    batch axis, validated with the ONNX checker and finally compared against
    ONNX Runtime on the same input.

    Args:
        modelname: a whole path name of the model weights that need to be loaded.
        outname: a path for the output onnx model.
        out_channels: output channels, which usually equals to 1 + class_number.
        height: input images' height.
        width: input images' width.
        multigpu: if the pre-trained model was trained (and saved) wrapped in
            ``torch.nn.DataParallel``.
        in_channels: input images' channel number.

    Raises:
        FileNotFoundError: if ``modelname`` does not exist.
        AssertionError: if the ONNX Runtime output deviates from the PyTorch
            output beyond ``rtol=1e-03, atol=1e-05``.
    """
    if not os.path.exists(modelname):
        # FileNotFoundError is still an Exception, so existing handlers keep working.
        raise FileNotFoundError("The specified model to load does not exist!")

    model = SEResNet50(spatial_dims=2, in_channels=in_channels, num_classes=out_channels)

    if multigpu:
        # Wrap before loading so the checkpoint keys saved from a DataParallel model match.
        model = torch.nn.DataParallel(model)
    # Honour the CPU fallback chosen in ``device`` instead of requiring CUDA.
    model = model.to(device)
    model.load_state_dict(torch.load(modelname, map_location=device))
    model = model.eval()

    # Seeded dummy input in NCHW layout, with the same channel count the network was built with.
    np.random.seed(0)
    x = np.random.random((1, in_channels, height, width))
    x = torch.tensor(x, dtype=torch.float32, device=device)

    # Reference output for the later comparison; no autograd graph is needed.
    with torch.no_grad():
        torch_out = model(x)

    input_names = ["INPUT__0"]
    output_names = ["OUTPUT__0"]

    # Export the bare network, not the DataParallel wrapper.
    model_trans = model.module if multigpu else model
    torch.onnx.export(
        model_trans,  # model to save
        x,  # model input
        outname,  # model save path
        export_params=True,
        verbose=True,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        opset_version=15,
        dynamic_axes={"INPUT__0": {0: "batch_size"}, "OUTPUT__0": {0: "batch_size"}},
    )

    # Structural validation of the file that was just written.
    onnx_model = onnx.load(outname)
    onnx.checker.check_model(onnx_model, full_check=True)
    ort_session = onnxruntime.InferenceSession(outname)

    def to_numpy(tensor):
        # detach() is a no-op for tensors that do not require grad.
        return tensor.detach().cpu().numpy()

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(["OUTPUT__0"], ort_inputs)
    numpy_torch_out = to_numpy(torch_out)
    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(numpy_torch_out, ort_outs[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")


def str2bool(value):
    """
    Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: a bool, or a string such as "true"/"false", "yes"/"no", "1"/"0".

    Returns:
        The parsed boolean. An empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # the original model for converting.
    parser.add_argument(
        "--model",
        type=str,
        default=r"/workspace/bundle/endoscopic_inbody_classification/models/model.pt",
        help="Input an existing model weight",
    )

    # path to save the onnx model.
    parser.add_argument(
        "--outpath",
        type=str,
        default=r"/workspace/bundle/endoscopic_inbody_classification/models/model.onnx",
        help="A path to save the onnx model.",
    )

    parser.add_argument("--width", type=int, default=256, help="Width for exporting onnx model.")

    parser.add_argument("--height", type=int, default=256, help="Height for exporting onnx model.")

    parser.add_argument(
        "--out_channels", type=int, default=2, help="Number of expected out_channels in model for exporting to onnx."
    )

    # Accepts "--multigpu", "--multigpu True" and "--multigpu False"; the last
    # form is now really parsed as False.
    parser.add_argument(
        "--multigpu",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="If loading model trained with multi gpu.",
    )

    args = parser.parse_args()
    modelname = args.model
    outname = args.outpath
    out_channels = args.out_channels
    height = args.height
    width = args.width
    multigpu = args.multigpu

    # Refuse to clobber an existing export.
    if os.path.exists(outname):
        raise FileExistsError(
            "The specified outpath already exists! Change the outpath to avoid overwriting your saved model. "
        )
    # The export routine returns nothing; it writes and verifies the ONNX file as a side effect.
    load_model_and_export(modelname, outname, out_channels, height, width, multigpu)