# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
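
# gRPC server exposing NeMo machine translation models through the
# RivaTranslate service (generated stubs in api.nmt_pb2 / api.nmt_pb2_grpc).
# One translation model is loaded per source-target language pair, and an
# optional punctuation and capitalization model can be applied to inputs
# first, which helps with unpunctuated speech recognition transcripts.
#
# Example invocation (the script name here is illustrative):
#
#   python translation_server.py --model_dir /models/nmt \
#       --punctuation_model /models/punct.nemo --port 50052
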
import argparse
import os
from concurrent import futures

import grpc
import torch

import nemo.collections.nlp as nemo_nlp
from nemo.utils import logging

import api.nmt_pb2 as nmt
import api.nmt_pb2_grpc as nmtsrv


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.",
)
parser.add_argument(
"--punctuation_model",
default="",
type=str,
help="Optionally provide a path a .nemo file for punctation and capitalization (recommend if working with Riva speech recognition outputs)",
)
parser.add_argument("--port", default=50052, type=int, required=False)
parser.add_argument("--batch_size", type=int, default=256, help="Maximum number of batches to process")
parser.add_argument("--beam_size", type=int, default=1, help="Beam Size")
parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty")
parser.add_argument("--max_delta_length", type=int, default=5, help="Max Delta Generation Length.")
args = parser.parse_args()
    return args


def batches(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
        yield lst[i : i + n]


class RivaTranslateServicer(nmtsrv.RivaTranslateServicer):
    """Implements the RivaTranslate gRPC service on top of NeMo translation models."""

def __init__(
self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256,
):
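        """Load every .nemo translation model found in `model_dir`, plus an optional punctuation model."""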
self._models = {}
self._beam_size = beam_size
self._len_pen = len_pen
self._max_delta_length = max_delta_length
self._batch_size = batch_size
self._punctuation_model_path = punctuation_model_path
self._model_dir = model_dir
model_paths = [os.path.join(model_dir, fname) for fname in os.listdir(model_dir) if fname.endswith('.nemo')]
        for model_path in model_paths:
assert os.path.exists(model_path)
logging.info(f"Loading model {model_path}")
self._load_model(model_path)
        if self._punctuation_model_path != "":
            assert os.path.exists(punctuation_model_path)
            logging.info(f"Loading punctuation model {punctuation_model_path}")
            self._load_punctuation_model(punctuation_model_path)
logging.info("Models loaded. Ready for inference requests.")

    def _load_punctuation_model(self, punctuation_model_path):
if punctuation_model_path.endswith(".nemo"):
self.punctuation_model = nemo_nlp.models.PunctuationCapitalizationModel.restore_from(
restore_path=punctuation_model_path
)
self.punctuation_model.eval()
else:
            raise NotImplementedError(f"Only .nemo files are supported, but got: {punctuation_model_path}")
if torch.cuda.is_available():
self.punctuation_model = self.punctuation_model.cuda()

    def _load_model(self, model_path):
if model_path.endswith(".nemo"):
logging.info("Attempting to initialize from .nemo file")
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path)
model = model.eval()
model.beam_search.beam_size = self._beam_size
model.beam_search.len_pen = self._len_pen
model.beam_search.max_delta_length = self._max_delta_length
if torch.cuda.is_available():
model = model.cuda()
else:
            raise NotImplementedError(f"Only .nemo files are supported, but got: {model_path}")
if not hasattr(model, "src_language") or not hasattr(model, "tgt_language"):
raise ValueError(
f"Could not find src_language and tgt_language in model attributes. If using NeMo rc1 checkpoints, please edit the config files to add model.src_language and model.tgt_language"
)
src_language = model.src_language
tgt_language = model.tgt_language
        if src_language not in self._models:
            self._models[src_language] = {}
        if tgt_language in self._models[src_language]:
            raise ValueError(f"A model for language pair {src_language}-{tgt_language} is already loaded")
        # The model was already moved to the GPU above when one is available.
        self._models[src_language][tgt_language] = model

    def TranslateText(self, request, context):
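        """Handle a unary translation request: validate the language pair, then translate batch by batch."""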
logging.info(f"Request received w/ {len(request.texts)} utterances")
results = []
        if (
            request.source_language not in self._models
            or request.target_language not in self._models[request.source_language]
        ):
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                f"Could not find source-target language pair {request.source_language}-{request.target_language} in list of models."
            )
            return nmt.TranslateTextResponse()
        request_strings = list(request.texts)
for batch in batches(request_strings, self._batch_size):
            if self._punctuation_model_path != "":
                # Restore punctuation and capitalization first, e.g. for raw ASR transcripts.
                batch = self.punctuation_model.add_punctuation_capitalization(batch)
batch_results = self._models[request.source_language][request.target_language].translate(text=batch)
translations = [nmt.Translation(translation=x) for x in batch_results]
results.extend(translations)
return nmt.TranslateTextResponse(translations=results)


def serve():
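    """Parse arguments, load the models, and serve gRPC requests until terminated."""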
args = get_args()
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
servicer = RivaTranslateServicer(
model_dir=args.model_dir,
punctuation_model_path=args.punctuation_model,
beam_size=args.beam_size,
len_pen=args.len_pen,
batch_size=args.batch_size,
max_delta_length=args.max_delta_length,
)
nmtsrv.add_RivaTranslateServicer_to_server(servicer, server)
server.add_insecure_port('[::]:' + str(args.port))
server.start()
server.wait_for_termination()


if __name__ == '__main__':
serve()
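
# A minimal client sketch for manual testing. The request message name
# (TranslateTextRequest) and stub class (RivaTranslateStub) follow the usual
# protoc naming for the RivaTranslate service, but are assumptions since the
# .proto definition is not shown here:
#
#   import grpc
#   import api.nmt_pb2 as nmt
#   import api.nmt_pb2_grpc as nmtsrv
#
#   with grpc.insecure_channel("localhost:50052") as channel:
#       stub = nmtsrv.RivaTranslateStub(channel)
#       request = nmt.TranslateTextRequest(
#           texts=["hello world"], source_language="en", target_language="de"
#       )
#       response = stub.TranslateText(request)
#       print([t.translation for t in response.translations])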