|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from unidecode import unidecode |
|
import tensorflow as tf |
|
import cloudpickle |
|
from transformers import DistilBertTokenizerFast |
|
import os |
|
|
|
def load_model(
    model_path="models/lang_detect_hf_distilbert.tflite",
    label_encoder_path="models/lang_detect_labelencoder.bin",
    model_checkpoint="distilbert-base-multilingual-cased",
):
    """Load the TFLite interpreter, the label encoder and the tokenizer.

    All arguments default to the artifacts this app ships with, so calling
    ``load_model()`` with no arguments behaves as before.

    Args:
        model_path: Path of the ``.tflite`` model file handed to
            ``tf.lite.Interpreter``. A relative path is resolved against the
            current working directory.
        label_encoder_path: Path of the pickled label encoder, opened in
            binary mode and deserialized with ``cloudpickle``.
        model_checkpoint: Identifier passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A tuple ``(interpreter, label_encoder, tokenizer)``.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)

    # NOTE: unpickling executes code embedded in the file, so only point
    # ``label_encoder_path`` at a file from a trusted source.
    with open(label_encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

    return interpreter, label_encoder, tokenizer
|
|
|
# Module-level handles shared by inference(); they are bound when this script
# is executed, before any UI element is drawn.
(interpreter, label_encoder, tokenizer) = load_model()
|
|
|
def inference(text):
    """Predict the language of ``text`` with the TFLite model.

    Uses the module-level ``tokenizer``, ``interpreter`` and
    ``label_encoder``.

    Args:
        text: The text to classify. ``None``, an empty string, or a string
            containing only whitespace is never sent to the model.

    Returns:
        ``"<LABEL> (<score>)"``, where ``<LABEL>`` is the upper-cased class
        name decoded by the label encoder and ``<score>`` is the model output
        for that class rounded to 3 decimals. Returns ``"Can't Predict"``
        when there is no text to classify.
    """
    # Whitespace-only input carries no signal; treat it like empty input
    # instead of asking the model to classify padding.
    if text is None or not text.strip():
        return "Can't Predict"

    # Fixed-length encoding: pad or truncate every input to 50 tokens.
    tokens = tokenizer(text, max_length=50, padding="max_length", truncation=True, return_tensors="tf")

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # NOTE(review): this relies on the model exposing attention_mask as input
    # 0 and input_ids as input 1 -- confirm against the converted model.
    interpreter.set_tensor(input_details[0]["index"], tokens["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], tokens["input_ids"])
    interpreter.invoke()

    # The first (and only) row of the output holds one value per class.
    scores = interpreter.get_tensor(output_details["index"])[0]
    best_index = np.argmax(scores)
    label = label_encoder.inverse_transform([best_index])[0].upper()

    return f"{label} ({str(np.round(scores[best_index], 3))})"
|
|
|
|
|
def main():
    """Draw the language-detection page and show a prediction on submit.

    Renders the title, the list of languages the model was trained on and a
    text area. When the "Submit" button is pressed, the entered text is
    passed to ``inference`` and the returned string is written to the page.
    """
    st.title("Language Detection")

    supported_languages = 'eng, rus, ita, tur, epo, ber, deu, kab, fra, por, spa, hun, jpn, heb, ukr, nld, fin, pol, mkd, lit, cmn, mar, ces, dan'.upper()
    st.write(f'Model is trained on the following languages \n{supported_languages}')

    user_text = st.text_area("Enter Text:", "", height=200)

    # Nothing to do until the user presses the button.
    if not st.button("Submit"):
        return

    st.write(inference(user_text))
|
|
|
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|