bentrevett commited on
Commit
33f47d9
1 Parent(s): b5a8c50

improved ux when pipe is loading

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -3,7 +3,7 @@ import transformers
3
  import matplotlib.pyplot as plt
4
 
5
 
6
- @st.cache(allow_output_mutation=True)
7
  def get_pipe():
8
  model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
9
  model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
@@ -21,12 +21,12 @@ st.set_page_config(page_title="Emotion Prediction")
21
  st.title("Emotion Prediction")
22
  st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")
23
 
24
- with st.spinner("Loading model..."):
25
- pipe = get_pipe()
26
-
27
  text = st.text_area('Enter text here:')
28
  submit = st.button('Predict')
29
 
 
 
 
30
  if (submit and len(text.strip()) > 0) or len(text.strip()) > 0:
31
 
32
  prediction = pipe(text)[0]
 
3
  import matplotlib.pyplot as plt
4
 
5
 
6
+ @st.cache(allow_output_mutation=True, show_spinner=False)
7
  def get_pipe():
8
  model_name = "joeddav/distilbert-base-uncased-go-emotions-student"
9
  model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
 
21
  st.title("Emotion Prediction")
22
  st.write("Type text into the text box and then press 'Predict' to get the predicted emotion.")
23
 
 
 
 
24
  text = st.text_area('Enter text here:')
25
  submit = st.button('Predict')
26
 
27
+ with st.spinner("Loading model..."):
28
+ pipe = get_pipe()
29
+
30
  if (submit and len(text.strip()) > 0) or len(text.strip()) > 0:
31
 
32
  prediction = pipe(text)[0]