# Hint: this cheatsheet is magic! https://cheat-sheet.streamlit.app/
# Streamlit front end for ALDi scoring: score one sentence typed by the user,
# or every line of an uploaded .txt file.
import base64

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
from altair import X, Y, Scale
from transformers import AutoTokenizer, BertForSequenceClassification

import constants

# Number of sentences tokenized and passed through the model per forward pass.
BATCH_SIZE = 4
ORANGE_COLOR = "#FF8000"


def render_svg(svg):
    """Render the given SVG string as an inline base64 ``<img>`` element.

    Deliberately not cached: the function exists for its side effect of
    emitting a Streamlit element, and the base64 encoding is cheap.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    html = f'<img src="data:image/svg+xml;base64,{b64}"/>'
    container = st.container()
    container.write(html, unsafe_allow_html=True)


@st.cache_data
def convert_df(df):
    """Return ``df`` serialized as UTF-8 CSV bytes, without the index column."""
    # IMPORTANT: Cache the conversion to prevent computation on every rerun.
    return df.to_csv(index=False).encode("utf-8")


@st.cache_resource
def load_model(model_name):
    """Load the classification model once and share it across reruns."""
    return BertForSequenceClassification.from_pretrained(model_name)


@st.cache_resource
def load_tokenizer(model_name):
    """Load the tokenizer once instead of on every script rerun."""
    return AutoTokenizer.from_pretrained(model_name)


tokenizer = load_tokenizer(constants.MODEL_NAME)
model = load_model(constants.MODEL_NAME)


def compute_ALDi(sentences, batch_size=BATCH_SIZE):
    """Score each sentence with the model, showing a progress bar meanwhile.

    Args:
        sentences: List of strings to score.
        batch_size: Number of sentences per forward pass.

    Returns:
        A list with one score per sentence, each clipped to [0, 1].
    """
    progress_text = "Computing ALDi..."
    my_bar = st.progress(0, text=progress_text)
    scores = []
    for first_index in range(0, len(sentences), batch_size):
        inputs = tokenizer(
            sentences[first_index : first_index + batch_size],
            return_tensors="pt",
            padding=True,
            # Cut over-long inputs down to the model's maximum input length
            # instead of failing on them.
            truncation=True,
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
        # Flattening assumes the model emits a single logit per sentence
        # (a one-output head) -- NOTE(review): confirm for constants.MODEL_NAME.
        scores.extend(max(min(o, 1), 0) for o in logits.reshape(-1).tolist())
        my_bar.progress(
            min((first_index + batch_size) / len(sentences), 1), text=progress_text
        )
    my_bar.empty()
    return scores


def _render_sentence_score(sentence):
    """Score a single sentence and draw its ALDi score as a horizontal bar."""
    ALDi_score = compute_ALDi([sentence])[0]

    fig, ax = plt.subplots(figsize=(8, 1))
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")
    ax.spines["left"].set_color(ORANGE_COLOR)
    ax.spines["bottom"].set_color(ORANGE_COLOR)
    ax.tick_params(axis="x", colors=ORANGE_COLOR)
    ax.spines[["right", "top"]].set_visible(False)
    ax.barh(y=[0], width=[ALDi_score], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"ALDi score is: {round(ALDi_score, 3)}", color=ORANGE_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("ALDi score", color=ORANGE_COLOR)
    st.pyplot(fig)
    # pyplot keeps every figure it creates alive until it is closed; since the
    # script reruns on each interaction, close it to avoid accumulating figures.
    plt.close(fig)


def _read_sentences(file):
    """Return the non-blank lines of an uploaded text file.

    The bytes are decoded as UTF-8 (a leading BOM is dropped) and split into
    lines directly rather than parsed with ``pd.read_csv``: the CSV parser
    would split lines containing tabs, strip quote characters, and turn lines
    such as "NA" or "123" into NaN/numbers that the tokenizer cannot handle.

    Raises:
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    text = file.getvalue().decode("utf-8-sig")
    return [line for line in text.splitlines() if line.strip()]


def _render_file_scores(file):
    """Score every sentence of an uploaded file; show chart, table and CSV."""
    try:
        sentences = _read_sentences(file)
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        return
    if not sentences:
        st.warning("The uploaded file does not contain any sentences.")
        return

    df = pd.DataFrame({"Sentence": sentences})
    df["ALDi"] = compute_ALDi(sentences)

    # A horizontal rule
    st.markdown("""---""")
    chart = (
        alt.Chart(df.reset_index())
        .mark_area(color="darkorange", opacity=0.5)
        .encode(
            x=X(field="index", title="Sentence Index"),
            y=Y("ALDi", scale=Scale(domain=[0, 1])),
        )
    )
    st.altair_chart(chart.interactive(), use_container_width=True)

    col1, col2 = st.columns([4, 1])
    with col1:
        # Display the output
        st.table(df)
    with col2:
        # Add a download button
        st.download_button(
            label=":file_folder: Download predictions as CSV",
            data=convert_df(df),
            file_name="ALDi_scores.csv",
            mime="text/csv",
        )


with open("assets/ALDi_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input("Arabic Sentence:", placeholder="Enter an Arabic sentence.")
    # The button only triggers a rerun; the score is shown whenever the text
    # box is non-empty.
    st.button("Submit")
    if sent:
        _render_sentence_score(sent)

with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        _render_file_scores(file)