santiviquez's picture
remove disable button
a065117
raw
history blame
2.54 kB
import streamlit as st
from transformers import pipeline
import pandas as pd
import nannyml as nml
# Seed the per-session batch counter the first time this session runs the script.
if "count" not in st.session_state:
    st.session_state["count"] = 0
def increment_counter():
    """Advance this session's batch counter by one.

    Used as the ``on_click`` callback of the batch-estimation button.
    """
    st.session_state["count"] = st.session_state["count"] + 1
@st.cache_resource
def get_model(url):
    """Build a transformers pipeline for the model identified by ``url``.

    ``st.cache_resource`` keeps the built pipeline across Streamlit reruns,
    so the model is not reloaded on every interaction. The padding /
    truncation / max_length options are forwarded to the pipeline as
    tokenizer settings (inputs capped at 512 tokens).
    """
    return pipeline(
        model=url,
        padding=True,
        truncation=True,
        max_length=512,
    )
# Sentiment classifier used by the single-review demo (cached by get_model).
rating_classification_model = get_model(
    "NannyML/amazon-reviews-sentiment-bert-base-uncased-6000-samples"
)
# Translate the pipeline's generic class ids (LABEL_0..LABEL_2) into
# human-readable sentiment names.
label_mapping = {
    f'LABEL_{class_id}': sentiment
    for class_id, sentiment in enumerate(('Negative', 'Neutral', 'Positive'))
}
# Single-review demo: classify one user-written review on demand.
review = st.text_input(label='write a review', value='I love this book!')
single_review_button = st.button(label='Classify Single Review')
if review and single_review_button:
    # The pipeline result is indexed with [0]; that entry carries 'label'/'score'.
    rating = rating_classification_model(review)[0]
    # Fall back to the raw pipeline label if it is ever missing from the mapping,
    # instead of crashing the page with a KeyError.
    label = label_mapping.get(rating['label'], rating['label'])
    score = rating['score']
    # Em dash restored (the previous text was mis-encoded mojibake "β€”").
    st.write(f"{label} — confidence: {round(score, 2)}")
# ---- Performance monitoring: estimated (CBPE) vs. realized performance ----
@st.cache_data
def _load_monitoring_data(reference_path='reference.csv', analysis_path='analysis.csv'):
    """Read the reference and analysis datasets and normalise their label columns.

    Streamlit re-executes this script on every widget interaction, so the
    CSVs are parsed once and cached with ``st.cache_data``. That cache hands
    each caller its own copy of the frames, so later mutation cannot leak
    between reruns.

    ``label`` and ``pred_label`` are cast to ``str`` so they line up with the
    string class keys ('0', '1', '2') used in the ``y_pred_proba`` mappings.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(reference_df, analysis_df)``.
    """
    frames = []
    for path in (reference_path, analysis_path):
        df = pd.read_csv(path)
        for column in ('label', 'pred_label'):
            df[column] = df[column].astype(str)
        frames.append(df)
    return tuple(frames)


reference_df, analysis_df = _load_monitoring_data()
# Shared configuration for the estimator and the calculator. Their results are
# compared later (results.compare(...)), so both must use the same probability
# columns, metrics and chunking; defining them once keeps the two in sync.
_PRED_PROBA_COLUMNS = {
    '0': 'pred_proba_label_negative',
    '1': 'pred_proba_label_neutral',
    '2': 'pred_proba_label_positive',
}
_MONITORED_METRICS = ['f1']
_CHUNK_SIZE = 400

# NOTE(review): Streamlit reruns this script on every interaction, so both
# fits below are repeated on each click; consider caching if this gets slow.

# Estimated performance; estimate() is later called on data without 'label'.
estimator = nml.CBPE(
    y_pred_proba=dict(_PRED_PROBA_COLUMNS),  # copies: objects never share config
    y_pred='pred_label',
    y_true='label',
    problem_type='classification_multiclass',
    metrics=list(_MONITORED_METRICS),
    chunk_size=_CHUNK_SIZE,
)
estimator.fit(reference_df)

# Realized performance; calculate() is later called on data including 'label'.
calculator = nml.PerformanceCalculator(
    y_pred_proba=dict(_PRED_PROBA_COLUMNS),
    y_true='label',
    y_pred='pred_label',
    problem_type='classification_multiclass',
    metrics=list(_MONITORED_METRICS),
    chunk_size=_CHUNK_SIZE,
)
calculator.fit(reference_df)
# Each click reveals one more batch of analysis rows (cumulatively); after the
# last batch the counter is reset so the next click starts over from batch 1.
_BATCH_SIZE = 400  # rows per click; kept equal to the chunk_size given to CBPE/PerformanceCalculator
_N_BATCHES = 5

multiple_reviews_button = st.button(
    f'Estimate Model Performance on {_BATCH_SIZE} Reviews',
    on_click=increment_counter,
)
if multiple_reviews_button:
    # on_click has already incremented the counter before this rerun, so the
    # slice covers the first count * _BATCH_SIZE rows (explicitly positional).
    n_rows = st.session_state.count * _BATCH_SIZE
    prod_data = analysis_df.iloc[:n_rows]
    # The estimator gets the data without ground truth; the calculator uses it
    # to compute realized performance for comparison.
    results = estimator.estimate(prod_data.drop(columns=['label']))
    realize_results = calculator.calculate(prod_data)
    fig = results.compare(realize_results).plot()
    st.plotly_chart(fig, use_container_width=True, theme=None)
    st.write(f'Batch {st.session_state.count} / {_N_BATCHES}')
    if st.session_state.count >= _N_BATCHES:
        st.session_state.count = 0