|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import time |
|
|
|
import flask |
|
import torch |
|
from flask import Flask, json, request |
|
from flask_cors import CORS |
|
|
|
import nemo.collections.nlp as nemo_nlp |
|
from nemo.utils import logging |
|
|
|
# Loaded translation models keyed by the language-pair names read from the
# config file; filled in by initialize() and looked up by get_translation().
MODELS_DICT = dict()

# Module-level placeholder. initialize() assigns to a *local* `model`, so the
# code in this module never rebinds this global.
model = None

# Flask application that exposes the /translate endpoint, wrapped by flask_cors.
api = Flask(__name__)
CORS(api)
|
|
|
|
|
def initialize(config_file_path: str) -> None:
    """
    Load the 'language-pair to NMT model' mapping and register each model in MODELS_DICT.

    The config file must contain a non-empty JSON object. Each key is the
    language-pair name clients pass as ``langpair`` to /translate. Each value
    says where the model comes from:

    * ``"NGC/<name>"`` -> ``MTEncDecModel.from_pretrained(model_name="<name>")``
    * anything else    -> ``MTEncDecModel.restore_from(restore_path=value)``

    Models are moved to the GPU with ``.cuda()`` when CUDA is available.

    Args:
        config_file_path: Path to the JSON config file.

    Raises:
        OSError: If the config file cannot be opened.
        ValueError: If the config is not a JSON object or contains no entries.
    """
    logging.info("Starting NMT service")
    logging.info(f"I will attempt to load all the models listed in {config_file_path}.")
    logging.info(f"Edit {config_file_path} to disable models you don't need.")

    # Query CUDA once; the answer is reused for every model below.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        logging.info("CUDA is available. Running on GPU")
    else:
        logging.info("CUDA is not available. Defaulting to CPUs")

    with open(config_file_path) as f:
        models_config = json.load(f)

    # An empty object loads as {} (not None), so test emptiness explicitly;
    # also reject non-object JSON, which would otherwise fail on `.items()`.
    if not isinstance(models_config, dict) or not models_config:
        raise ValueError("Did not find the config.json or it was empty")

    ngc_prefix = "NGC/"
    for key, value in models_config.items():
        logging.info(f"Loading model for {key} from file: {value}")
        if value.startswith(ngc_prefix):
            model = nemo_nlp.models.machine_translation.MTEncDecModel.from_pretrained(
                model_name=value[len(ngc_prefix):]
            )
        else:
            model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=value)
        if use_cuda:
            model = model.cuda()
        MODELS_DICT[key] = model

    logging.info("NMT service started")
|
|
|
|
|
def _translation_response(translation: str, status: int = 200):
    """Build a JSON ``{"translation": ...}`` response with a permissive CORS origin header."""
    response = flask.jsonify({'translation': translation})
    response.status_code = status
    response.headers.add('Access-Control-Allow-Origin', '*')
    return response


def _is_truthy_flag(value) -> bool:
    """Interpret a query-string flag: "1", "true", "yes", "on" (any case) mean True; absent means False."""
    return value is not None and value.strip().lower() in ("1", "true", "yes", "on")


@api.route('/translate', methods=['GET', 'POST', 'OPTIONS'])
def get_translation():
    """
    Translate the ``text`` query argument with the model registered under ``langpair``.

    Query arguments:
        langpair: Key into MODELS_DICT (required).
        text: Source text to translate (required).
        do_moses: Optional flag ("1"/"true"/"yes"/"on"). When set, the source
            and target languages are taken from ``langpair`` split on ``-``
            and passed to ``translate``.

    Returns:
        A JSON response ``{"translation": <text>}``. On failure the same key
        carries the error message: 400 for missing arguments or an unknown
        ``langpair``, 500 for any other error. OPTIONS requests (e.g. CORS
        preflight) get Flask's default options response.
    """
    # 'OPTIONS' is listed explicitly, so Flask does not answer it
    # automatically; reply here instead of running the translation path.
    if request.method == 'OPTIONS':
        return api.make_default_options_response()

    try:
        time_s = time.perf_counter()
        langpair = request.args.get("langpair")
        src = request.args.get("text")
        if langpair is None or src is None:
            return _translation_response("Both 'langpair' and 'text' query arguments are required", 400)

        if langpair not in MODELS_DICT:
            message = f"Got the following langpair: {langpair} which was not found"
            logging.error(message)
            return _translation_response(message, 400)

        translator = MODELS_DICT[langpair]
        # Query values are strings, so "false"/"0" must not count as truthy.
        if _is_truthy_flag(request.args.get('do_moses')):
            lang_parts = langpair.split('-')
            result = translator.translate([src], source_lang=lang_parts[0], target_lang=lang_parts[1])
        else:
            result = translator.translate([src])

        duration = time.perf_counter() - time_s
        logging.info(f"Translated in {duration}. Input was: {src} <############> Translation was: {result[0]}")
        return _translation_response(result[0])
    except Exception as ex:
        # Top-level request boundary: log and report instead of crashing the worker.
        logging.error(f"Translation request failed: {ex}")
        return _translation_response(str(ex), 500)
|
|
|
|
|
if __name__ == '__main__':
    # Load every configured model before serving, so MODELS_DICT is populated
    # by the time the first request arrives.
    initialize(config_file_path='config.json')
    # Listen on all interfaces.
    api.run(host='0.0.0.0')
|
|