|
import argparse |
|
import os |
|
|
|
import numpy as np |
|
import onnx |
|
import onnxruntime |
|
import torch |
|
from monai.networks.nets import SEResNet50 |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
def load_model_and_export(modelname, outname, out_channels, height, width, multigpu=False, in_channels=3): |
|
""" |
|
Loading a model by name. |
|
|
|
Args: |
|
modelname: a whole path name of the model that need to be loaded. |
|
outname: a name for output onnx model. |
|
out_channels: output channels, which usually equals to 1 + class_number. |
|
height: input images' height. |
|
width: input images' width. |
|
multigpu: if the pre-trained model trained on a multigpu environment. |
|
in_channels: input images' channel number. |
|
""" |
|
isopen = os.path.exists(modelname) |
|
if not isopen: |
|
raise Exception("The specified model to load does not exist!") |
|
|
|
model = SEResNet50(spatial_dims=2, in_channels=in_channels, num_classes=out_channels) |
|
|
|
if multigpu: |
|
model = torch.nn.DataParallel(model) |
|
model = model.cuda() |
|
model.load_state_dict(torch.load(modelname, map_location=device)) |
|
model = model.eval() |
|
|
|
np.random.seed(0) |
|
x = np.random.random((1, 3, width, height)) |
|
x = torch.tensor(x, dtype=torch.float32) |
|
x = x.cuda() |
|
torch_out = model(x) |
|
input_names = ["INPUT__0"] |
|
output_names = ["OUTPUT__0"] |
|
|
|
if multigpu: |
|
model_trans = model.module |
|
else: |
|
model_trans = model |
|
torch.onnx.export( |
|
model_trans, |
|
x, |
|
outname, |
|
export_params=True, |
|
verbose=True, |
|
do_constant_folding=True, |
|
input_names=input_names, |
|
output_names=output_names, |
|
opset_version=15, |
|
dynamic_axes={"INPUT__0": {0: "batch_size"}, "OUTPUT__0": {0: "batch_size"}}, |
|
) |
|
onnx_model = onnx.load(outname) |
|
onnx.checker.check_model(onnx_model, full_check=True) |
|
ort_session = onnxruntime.InferenceSession(outname) |
|
|
|
def to_numpy(tensor): |
|
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() |
|
|
|
|
|
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} |
|
ort_outs = ort_session.run(["OUTPUT__0"], ort_inputs) |
|
numpy_torch_out = to_numpy(torch_out) |
|
|
|
np.testing.assert_allclose(numpy_torch_out, ort_outs[0], rtol=1e-03, atol=1e-05) |
|
print("Exported model has been tested with ONNXRuntime, and the result looks good!") |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument( |
|
"--model", |
|
type=str, |
|
default=r"/workspace/bundle/endoscopic_inbody_classification/models/model.pt", |
|
help="Input an existing model weight", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--outpath", |
|
type=str, |
|
default=r"/workspace/bundle/endoscopic_inbody_classification/models/model.onnx", |
|
help="A path to save the onnx model.", |
|
) |
|
|
|
parser.add_argument("--width", type=int, default=256, help="Width for exporting onnx model.") |
|
|
|
parser.add_argument("--height", type=int, default=256, help="Height for exporting onnx model.") |
|
|
|
parser.add_argument( |
|
"--out_channels", type=int, default=2, help="Number of expected out_channels in model for exporting to onnx." |
|
) |
|
|
|
parser.add_argument("--multigpu", type=bool, default=False, help="If loading model trained with multi gpu.") |
|
|
|
args = parser.parse_args() |
|
modelname = args.model |
|
outname = args.outpath |
|
out_channels = args.out_channels |
|
height = args.height |
|
width = args.width |
|
multigpu = args.multigpu |
|
|
|
if os.path.exists(outname): |
|
raise Exception( |
|
"The specified outpath already exists! Change the outpath to avoid overwriting your saved model. " |
|
) |
|
model = load_model_and_export(modelname, outname, out_channels, height, width, multigpu) |
|
|