# ksvmuralidhar's picture
# Create app.py
# 4c44d2b
import streamlit as st
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
def load_model():
    """Load every artifact the app needs for inference.

    Reads the TFLite classifier, the pickled (text preprocessor, label
    encoder) pair and the pickled label map from the ``models/`` directory,
    and loads the ``distilbert-base-uncased`` tokenizer.

    Returns:
        tuple: ``(interpreter, text_preprocessor, label_encoder, tokenizer,
        label_map)`` in that order.
    """
    tflite_path = os.path.join("models/dbpedia_classifier_hf_distilbert_l3.tflite")
    interpreter = tf.lite.Interpreter(model_path=tflite_path)

    # The preprocessor and the label encoder are stored together in one file.
    with open("models/preprocessor_labelencoder_l3.bin", "rb") as artifacts_file:
        text_preprocessor, label_encoder = cloudpickle.load(artifacts_file)

    with open("models/label_map_l3.bin", "rb") as label_map_file:
        label_map = cloudpickle.load(label_map_file)

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

    return interpreter, text_preprocessor, label_encoder, tokenizer, label_map
# Load the inference artifacts when the module is executed; inference() reads
# these module-level names.
(
    interpreter,
    text_preprocessor,
    label_encoder,
    tokenizer,
    label_map,
) = load_model()
def inference(text):
    """Classify an article and return the predicted category with its score.

    Args:
        text: Raw article text. ``None``, an empty string or a
            whitespace-only string is treated as "no input".

    Returns:
        str: ``"<label> (<score>)"``, where ``<score>`` is the largest model
        output rounded to 5 decimals, or ``"Can't Predict"`` when there is
        no usable input.
    """
    # Blank input carries nothing to classify, so return the fallback before
    # touching the preprocessor, tokenizer or model.
    if text is None or not str(text).strip():
        return "Can't Predict"

    cleaned_text = text_preprocessor.preprocess(pd.Series(text))[0]
    tokens = tokenizer(
        cleaned_text,
        max_length=200,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # NOTE(review): assumes the model's first input is the attention mask and
    # the second is the token ids -- confirm against the converted model.
    interpreter.set_tensor(input_details[0]["index"], tokens["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], tokens["input_ids"])
    interpreter.invoke()

    scores = interpreter.get_tensor(output_details["index"])[0]
    best_index = np.argmax(scores)

    # Encoded class index -> encoder label -> display name from label_map.
    label = label_map[label_encoder.inverse_transform([best_index])[0]]
    return f"{label} ({str(np.round(scores[best_index], 5))})"
def main():
    """Render the Streamlit page and show a prediction on submit.

    Draws the title, an HTML intro paragraph and a text area; when the
    "Submit" button is pressed, runs ``inference`` on the pasted text and
    writes the result to the page.
    """
    st.title("Wikipedia Article Classification")

    intro_html = '<p>The model is fine-tuned to classify an article into <a href="https://huggingface.co/spaces/ksvmuralidhar/wiki-article-classification/blob/main/categories.csv" target="_blank">219 categories</a></p>'
    st.markdown(intro_html, unsafe_allow_html=True)

    article_text = st.text_area("Paste an article:", "", height=200)

    # Nothing to do until the user presses the button.
    if not st.button("Submit"):
        return

    st.write(inference(article_text))
# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()