|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import sys |
|
|
|
import torch |
|
from pytorch_lightning import Trainer |
|
|
|
import nemo |
|
from nemo.core import ModelPT |
|
from nemo.core.classes import Exportable |
|
from nemo.core.config.pytorch_lightning import TrainerConfig |
|
from nemo.utils import logging |
|
|
|
try: |
|
from contextlib import nullcontext |
|
except ImportError: |
|
|
|
from contextlib import suppress as nullcontext |
|
|
|
|
|
def get_args(argv):
    """Parse command-line arguments for the NeMo export script.

    Args:
        argv: Argument strings, excluding the program name (e.g. ``sys.argv[1:]``).

    Returns:
        An ``argparse.Namespace`` holding the positional ``source`` and ``out``
        paths plus every optional flag declared below.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Export NeMo models to ONNX/Torchscript",
    )
    parser.add_argument("source", help="Source .nemo file")
    parser.add_argument("out", help="Location to write result to")
    parser.add_argument("--autocast", action="store_true", help="Use autocast when exporting")
    parser.add_argument("--runtime-check", action="store_true", help="Runtime check of exported net result")
    # The value is resolved by name (e.g. "DEBUG") by the caller, so the help
    # text advertises level names rather than numbers.
    parser.add_argument("--verbose", default=None, help="Logging level name, e.g. DEBUG or INFO")
    parser.add_argument("--max-batch", type=int, default=None, help="Max batch size for model export")
    parser.add_argument("--max-dim", type=int, default=None, help="Max dimension(s) for model export")
    parser.add_argument("--onnx-opset", type=int, default=None, help="ONNX opset for model export")
    parser.add_argument(
        # "--cache_support" comes first so the attribute name stays
        # ``cache_support``; "--cache-support" matches the hyphenated style of
        # the other options.
        "--cache_support",
        "--cache-support",
        action="store_true",
        help="Enable input caching for models that support it.",
    )
    parser.add_argument("--device", default="cuda", help="Device to export for")
    parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
    parser.add_argument(
        "--export-config",
        metavar="KEY=VALUE",
        nargs='+',
        help="Set a number of key-value pairs to model.export_config dictionary "
        "(do not put spaces before or after the = sign). "
        "Note that values are always treated as strings.",
    )

    return parser.parse_args(argv)
|
|
|
|
|
def _resolve_log_level(verbose):
    """Translate the ``--verbose`` value into a numeric logging level.

    Args:
        verbose: ``None`` (use the INFO default), a level name looked up
            case-insensitively as an attribute of the NeMo ``logging`` object
            (e.g. ``"debug"``), or a non-negative integer string (e.g. ``"10"``).

    Returns:
        The numeric logging level.

    Raises:
        ValueError: if ``verbose`` is neither an integer nor a known level name.
    """
    if verbose is None:
        return logging.INFO
    # Accept plain integers as well as level names.
    if verbose.isdigit():
        return int(verbose)
    level = getattr(logging, verbose.upper(), None)
    if not isinstance(level, int):
        # Report what the user typed, not the failed lookup result.
        raise ValueError('Invalid log level: %s' % verbose)
    return level


def _parse_export_config(pairs):
    """Parse ``KEY=VALUE`` strings from ``--export-config`` into a dict.

    Values stay strings. Only the first ``=`` splits key from value, so values
    may themselves contain ``=``. A later duplicate key overrides an earlier one.

    Raises:
        ValueError: if an entry has no ``=`` or has an empty key.
    """
    config = {}
    for pair in pairs:
        key, sep, value = pair.partition("=")
        if not sep or not key:
            raise ValueError("Use correct format for --export-config: k=v (got {!r})".format(pair))
        config[key] = value
    return config


def nemo_export(argv):
    """Export a ``.nemo`` model to ONNX/TorchScript according to ``argv``.

    Parses ``argv`` with ``get_args`` and sets the NeMo logging level. It then
    restores the model with ``ModelPT.restore_from``, applies any
    ``--cache_support`` / ``--export-config`` settings, and calls
    ``model.export`` to write ``args.out``.

    Args:
        argv: Command-line arguments, excluding the program name.

    Raises:
        ValueError: on an invalid ``--verbose`` level or a malformed
            ``--export-config`` entry.
        SystemExit: with code 1 if the restored model is not ``Exportable``.
        Exception: any error from restoring or exporting is logged, then re-raised.
    """
    args = get_args(argv)

    loglevel = _resolve_log_level(args.verbose)
    logging.setLevel(loglevel)
    logging.info("Logging level set to {}".format(loglevel))

    nemo_in = args.source
    out = args.out

    # restore_from receives this Trainer. NOTE(review): the GPU/DDP settings are
    # fixed regardless of --device, and the config dataclass is passed
    # positionally to Trainer. Confirm this matches the installed
    # pytorch_lightning Trainer signature.
    cfg_trainer = TrainerConfig(
        accelerator='gpu',
        strategy="ddp",
        num_nodes=1,
        devices=1,
        logger=False,
        enable_checkpointing=False,
    )
    trainer = Trainer(cfg_trainer)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise

    model_name = model.__class__.__name__
    logging.info("Model {} restored from '{}'".format(model_name, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model_name))
        sys.exit(1)

    # Size limits forwarded to input_example() when building trace-check inputs.
    in_args = {}
    if args.max_batch is not None:
        in_args["max_batch"] = args.max_batch
    if args.max_dim is not None:
        in_args["max_dim"] = args.max_dim

    if args.cache_support:
        model.set_export_config({"cache_support": "True"})

    if args.export_config:
        model.set_export_config(_parse_export_config(args.export_config))

    autocast = torch.cuda.amp.autocast if args.autocast else nullcontext
    check_trace = args.runtime_check
    try:
        with autocast(), torch.no_grad(), torch.inference_mode():
            model.to(device=args.device).freeze()
            model.eval()
            input_example = None
            if check_trace and in_args:
                # Verify the trace on two inputs: one at the requested sizes and
                # one at roughly half of each size (rounded up).
                input_example = model.input_module.input_example(**in_args)
                half_args = {key: (arg + 1) // 2 for key, arg in in_args.items()}
                check_trace = [input_example, model.input_module.input_example(**half_args)]
                logging.info(f"Using additional check args: {half_args}")

            model.export(
                out,
                input_example=input_example,
                check_trace=check_trace,
                check_tolerance=args.check_tolerance,
                onnx_opset_version=args.onnx_opset,
                verbose=bool(args.verbose),
            )
    except Exception:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model_name
            )
        )
        raise
|
|
|
|
|
if __name__ == "__main__":
    # Drop the program name; nemo_export expects only the user-supplied arguments.
    cli_args = sys.argv[1:]
    nemo_export(cli_args)
|
|