crystal-technologies's picture
Upload 1287 files
2d8da09
raw
history blame
7.33 kB
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import torch
from pytorch_lightning import Trainer
import nemo
from nemo.core import ModelPT
from nemo.core.classes import Exportable
from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import logging
try:
from contextlib import nullcontext
except ImportError:
# handle python < 3.7
from contextlib import suppress as nullcontext
def get_args(argv):
    """Parse the command-line arguments of the export script.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace holding the parsed options. ``--cache_support`` and its
        alias ``--cache-support`` both store into ``cache_support``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Export NeMo models to ONNX/Torchscript",
    )
    parser.add_argument("source", help="Source .nemo file")
    parser.add_argument("out", help="Location to write result to")
    parser.add_argument("--autocast", action="store_true", help="Use autocast when exporting")
    parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
    # The value is resolved by name (e.g. "debug" -> DEBUG), not as a number.
    parser.add_argument(
        "--verbose",
        default=None,
        help="Logging level name, e.g. DEBUG or INFO (case-insensitive). Any value also enables verbose export.",
    )
    parser.add_argument("--max-batch", type=int, default=None, help="Max batch size for model export")
    parser.add_argument("--max-dim", type=int, default=None, help="Max dimension(s) for model export")
    parser.add_argument("--onnx-opset", type=int, default=None, help="ONNX opset for model export")
    # The first long option fixes dest="cache_support"; the dashed alias matches the other flags.
    parser.add_argument(
        "--cache_support",
        "--cache-support",
        action="store_true",
        help="Enable caching inputs for models that support it.",
    )
    parser.add_argument("--device", default="cuda", help="Device to export for")
    parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
    parser.add_argument(
        "--export-config",
        metavar="KEY=VALUE",
        nargs='+',
        help="Set a number of key-value pairs to model.export_config dictionary "
        "(do not put spaces before or after the = sign). "
        "Note that values are always treated as strings.",
    )
    args = parser.parse_args(argv)
    return args
def _resolve_log_level(verbose):
    """Translate the ``--verbose`` value into a NeMo logging level.

    Args:
        verbose: Level name as typed on the command line (e.g. ``debug`` or ``INFO``;
            case-insensitive), or ``None`` when the option was not given.

    Returns:
        ``logging.INFO`` when ``verbose`` is ``None``; otherwise the integer attribute of
        ``logging`` named ``verbose.upper()``.

    Raises:
        ValueError: If ``verbose`` does not name an integer level attribute of ``logging``.
    """
    if verbose is None:
        return logging.INFO
    numeric_level = getattr(logging, verbose.upper(), None)
    if not isinstance(numeric_level, int):
        # Report what the user typed; numeric_level is None (or not a level) here.
        raise ValueError('Invalid log level: %s' % verbose)
    return numeric_level


def _parse_export_config(pairs):
    """Build the dict passed to ``model.set_export_config`` from ``KEY=VALUE`` strings.

    Args:
        pairs: Strings given to ``--export-config``. Only the first ``=`` separates key
            from value, so a value may itself contain ``=``. Values stay strings.

    Returns:
        dict mapping each key to its string value; a repeated key keeps the last value.

    Raises:
        ValueError: If an entry has no ``=`` or has an empty key.
    """
    kv = {}
    for key_value in pairs:
        key, sep, value = key_value.partition("=")
        if not sep or not key:
            raise ValueError("Use correct format for --export-config: k=v (got '{}')".format(key_value))
        kv[key] = value
    return kv


def _build_trainer(device):
    """Create the Lightning Trainer that ``ModelPT.restore_from`` needs for Megatron models.

    Args:
        device: The ``--device`` value. Values starting with ``"cpu"`` select the CPU
            accelerator; anything else keeps the ``'gpu'`` accelerator.

    Returns:
        A single-node, single-device Trainer with logging and checkpointing disabled.
    """
    from dataclasses import asdict

    cfg_trainer = TrainerConfig(
        accelerator='cpu' if str(device).startswith('cpu') else 'gpu',
        strategy="ddp",
        num_nodes=1,
        devices=1,
        # Need to set the following two to False as ExpManager will take care of them differently.
        logger=False,
        enable_checkpointing=False,
    )
    # Passing the config object positionally (Trainer(cfg_trainer)) does not apply its fields.
    # Expand them as keyword arguments instead.
    # NOTE(review): assumes TrainerConfig is a dataclass whose fields match this Lightning
    # version's Trainer kwargs; asdict() raises TypeError if it is not a dataclass.
    return Trainer(**asdict(cfg_trainer))


def nemo_export(argv):
    """Restore a ``.nemo`` checkpoint and export it with the model's ``export()`` method.

    Args:
        argv: Command-line arguments without the program name (see ``get_args``).

    Exits with status 1 if the restored model is not ``Exportable``. Errors raised while
    restoring or exporting are logged with a hint and then re-raised.
    """
    args = get_args(argv)

    loglevel = _resolve_log_level(args.verbose)
    logging.setLevel(loglevel)
    logging.info("Logging level set to {}".format(loglevel))

    nemo_in = args.source
    out = args.out

    # Create a PL trainer object which is required for restoring Megatron models
    trainer = _build_trainer(args.device)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            # Restore instance from .nemo file using generic model restore_from
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise
    logging.info("Model {} restored from '{}'".format(model.__class__.__name__, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model.__class__.__name__))
        sys.exit(1)

    #
    # Add custom export parameters here
    #
    check_trace = args.runtime_check
    # Forward only the size limits given on the command line to input_example().
    in_args = {}
    if args.max_batch is not None:
        in_args["max_batch"] = args.max_batch
    if args.max_dim is not None:
        in_args["max_dim"] = args.max_dim

    if args.cache_support:
        model.set_export_config({"cache_support": "True"})
    if args.export_config:
        model.set_export_config(_parse_export_config(args.export_config))

    autocast = torch.cuda.amp.autocast if args.autocast else nullcontext
    try:
        with autocast(), torch.no_grad(), torch.inference_mode():
            model.to(device=args.device).freeze()
            model.eval()
            input_example = None
            if check_trace and in_args:
                # Check the trace on two inputs: one at the requested limits and one at
                # roughly half of them (rounded up).
                input_example = model.input_module.input_example(**in_args)
                half_args = {key: (arg + 1) // 2 for key, arg in in_args.items()}
                check_trace = [input_example, model.input_module.input_example(**half_args)]
                logging.info(f"Using additional check args: {half_args}")
            model.export(
                out,
                input_example=input_example,
                check_trace=check_trace,
                check_tolerance=args.check_tolerance,
                onnx_opset_version=args.onnx_opset,
                # NOTE(review): any --verbose value (even "ERROR") enables verbose export output.
                verbose=bool(args.verbose),
            )
    except Exception:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.__class__.__name__
            )
        )
        raise
if __name__ == '__main__':
    # Drop argv[0] (the script path): get_args expects only the user-supplied options.
    cli_arguments = sys.argv[1:]
    nemo_export(cli_arguments)