|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from unidecode import unidecode |
|
import tensorflow as tf |
|
import cloudpickle |
|
from transformers import DistilBertTokenizerFast |
|
import os |
|
|
|
@st.cache_resource
def load_model():
    """Load and cache the TFLite classifier and its preprocessing artifacts.

    Returns:
        tuple: (interpreter, text_preprocessor, label_encoder, tokenizer, label_map)
            - interpreter: tf.lite.Interpreter over the DistilBERT classifier.
            - text_preprocessor, label_encoder: objects unpickled from
              models/preprocessor_labelencoder_l3.bin.
            - tokenizer: DistilBertTokenizerFast for "distilbert-base-uncased".
            - label_map: mapping unpickled from models/label_map_l3.bin.
    """
    # st.cache_resource keeps these heavyweight objects alive across Streamlit
    # script reruns instead of reloading them on every user interaction.
    # (os.path.join with a single argument was a no-op and has been dropped.)
    interpreter = tf.lite.Interpreter(model_path="models/dbpedia_classifier_hf_distilbert_l3.tflite")

    # NOTE(review): cloudpickle deserialization assumes these files are trusted
    # local artifacts — never point these paths at untrusted data.
    with open("models/preprocessor_labelencoder_l3.bin", "rb") as model_file_obj:
        text_preprocessor, label_encoder = cloudpickle.load(model_file_obj)

    with open("models/label_map_l3.bin", "rb") as model_file_obj:
        label_map = cloudpickle.load(model_file_obj)

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    return interpreter, text_preprocessor, label_encoder, tokenizer, label_map
|
|
|
# Shared artifacts used by inference(); executed at module level so every part
# of the script sees the same interpreter/tokenizer/encoder objects.
interpreter, text_preprocessor, label_encoder, tokenizer, label_map = load_model()
|
|
|
def inference(text):
    """Classify *text* with the TFLite model.

    Returns a string of the form "<category> (<confidence>)", or
    "Can't Predict" when the input is empty.
    """
    if text == "":
        return "Can't Predict"

    # Clean the raw article text, then tokenize to fixed-length (200) tensors.
    cleaned = text_preprocessor.preprocess(pd.Series(text))[0]
    encoded = tokenizer(cleaned, max_length=200, padding="max_length",
                        truncation=True, return_tensors="tf")

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_detail = interpreter.get_output_details()[0]

    # NOTE(review): assumes input slot 0 is attention_mask and slot 1 is
    # input_ids — this ordering comes from the model conversion; verify
    # against the .tflite file's signature.
    interpreter.set_tensor(input_details[0]["index"], encoded["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], encoded["input_ids"])
    interpreter.invoke()

    probs = interpreter.get_tensor(output_detail["index"])[0]
    best = np.argmax(probs)
    category = label_map[label_encoder.inverse_transform([best])[0]]
    return f"{category} ({str(np.round(probs[best], 5))})"
|
|
|
|
|
def main():
    """Render the Streamlit page: title, description, input box and result."""
    st.title("Wikipedia Article Classification")
    st.markdown(
        '<p>The model is fine-tuned to classify an article into <a href="https://huggingface.co/spaces/ksvmuralidhar/wiki-article-classification/blob/main/categories.csv" target="_blank">219 categories</a></p>',
        unsafe_allow_html=True,
    )
    article_text = st.text_area("Paste an article:", "", height=200)
    if st.button("Submit"):
        prediction = inference(article_text)
        st.write(prediction)
|
|
|
# Script entry point — Streamlit executes this module directly.
if __name__ == "__main__":
    main()
|
|