# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
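"""
Export a NeMo .nemo checkpoint to ONNX or TorchScript.

Illustrative invocations (file names below are placeholders, not shipped assets):

    python export.py model.nemo model.onnx --autocast --runtime-check
    python export.py model.nemo model.ts --max-batch 8
    python export.py model.nemo model.onnx --export-config cache_support=True

The export format is typically inferred by NeMo from the extension of the output
path (e.g. .onnx vs .ts). An ONNX result can then be loaded with, for example,
onnxruntime (input/output names and shapes depend on the particular model class):

    import onnxruntime as ort
    sess = ort.InferenceSession("model.onnx")
"""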
import argparse
import sys

import torch
from pytorch_lightning import Trainer

import nemo
from nemo.core import ModelPT
from nemo.core.classes import Exportable
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging

try:
    from contextlib import nullcontext
except ImportError:
    # handle python < 3.7
    from contextlib import suppress as nullcontext


def get_args(argv):
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Export NeMo models to ONNX/TorchScript",
    )
    parser.add_argument("source", help="Source .nemo file")
    parser.add_argument("out", help="Location to write result to")
    parser.add_argument("--autocast", action="store_true", help="Use autocast when exporting")
    parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
    parser.add_argument("--verbose", default=None, help="Verbose level for logging, numeric")
    parser.add_argument("--max-batch", type=int, default=None, help="Max batch size for model export")
    parser.add_argument("--max-dim", type=int, default=None, help="Max dimension(s) for model export")
    parser.add_argument("--onnx-opset", type=int, default=None, help="ONNX opset for model export")
    parser.add_argument(
        "--cache_support", action="store_true", help="Enables caching inputs for models that support it."
    )
    parser.add_argument("--device", default="cuda", help="Device to export for")
    parser.add_argument("--check-tolerance", type=float, default=0.01, help="Tolerance for verification")
    parser.add_argument(
        "--export-config",
        metavar="KEY=VALUE",
        nargs='+',
        help="Set one or more KEY=VALUE pairs in the model.export_config dictionary "
        "(do not put spaces before or after the = sign). "
        "Note that values are always treated as strings.",
    )
    args = parser.parse_args(argv)
    return args


def nemo_export(argv):
    args = get_args(argv)

    loglevel = logging.INFO
    # args.verbose holds the log level as a string from the command line;
    # convert it to upper case so the user can pass e.g. --verbose=DEBUG or --verbose=debug
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level

    logging.setLevel(loglevel)
    logging.info("Logging level set to {}".format(loglevel))

    # Convert the .nemo saved model into an exported ONNX/TorchScript model.
    nemo_in = args.source
    out = args.out

    # Create a PL trainer object which is required for restoring Megatron models
    cfg_trainer = TrainerConfig(
        accelerator='gpu',
        strategy="ddp",
        num_nodes=1,
        devices=1,
        # Need to set the following two to False as ExpManager will take care of them differently.
        logger=False,
        enable_checkpointing=False,
    )
    trainer = Trainer(cfg_trainer)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            # Restore instance from .nemo file using generic model restore_from
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception as e:
        logging.error(
            "Failed to restore model from NeMo file: {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise e

    logging.info("Model {} restored from '{}'".format(model.__class__.__name__, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model.__class__.__name__))
        sys.exit(1)
    #
    # Add custom export parameters here
    #
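    # Optional shape limits from --max-batch/--max-dim are collected into in_args and
    # forwarded to input_example() below when inputs for the runtime check are generated.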
    check_trace = args.runtime_check
    in_args = {}
    max_batch = 1
    max_dim = None
    if args.max_batch is not None:
        in_args["max_batch"] = args.max_batch
        max_batch = args.max_batch
    if args.max_dim is not None:
        in_args["max_dim"] = args.max_dim
        max_dim = args.max_dim
    if args.cache_support:
        model.set_export_config({"cache_support": "True"})

    if args.export_config:
        kv = {}
        for key_value in args.export_config:
            lst = key_value.split("=")
            if len(lst) != 2:
                raise Exception("Use correct format for --export-config: KEY=VALUE")
            k, v = lst
            kv[k] = v
        model.set_export_config(kv)
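    # Export under CUDA autocast (mixed precision) only when --autocast was given;
    # otherwise nullcontext leaves the computation at full precision.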
    autocast = nullcontext
    if args.autocast:
        autocast = torch.cuda.amp.autocast
    try:
        with autocast(), torch.no_grad(), torch.inference_mode():
            model.to(device=args.device).freeze()
            model.eval()
            input_example = None
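            # If extra shape arguments were given, build a second, smaller input example
            # (each size roughly halved) so the runtime check also exercises dynamic
            # shapes rather than a single fixed-size input.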
            if check_trace and len(in_args) > 0:
                input_example = model.input_module.input_example(**in_args)
                check_trace = [input_example]
                for key, arg in in_args.items():
                    in_args[key] = (arg + 1) // 2
                input_example2 = model.input_module.input_example(**in_args)
                check_trace.append(input_example2)
                logging.info(f"Using additional check args: {in_args}")

            _, descriptions = model.export(
                out,
                input_example=input_example,
                check_trace=check_trace,
                check_tolerance=args.check_tolerance,
                onnx_opset_version=args.onnx_opset,
                verbose=bool(args.verbose),
            )
    except Exception as e:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.__class__
            )
        )
        raise e


if __name__ == '__main__':
    nemo_export(sys.argv[1:])