|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
# Page heading for the app.
st.title("Emotion Detection with Transformers")

# Multi-line input widget; its current contents are passed to
# analyze_emotion further down in the script.
user_input = st.text_area(label="Enter your text:")
|
|
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and seq2seq model used for emotion detection.

    The result is cached with ``st.cache_resource`` instead of
    ``st.cache_data``: the returned objects are a tokenizer and a PyTorch
    model, i.e. shared resources that should be created once and reused
    as-is, whereas ``st.cache_data`` serializes the return value and hands
    back a fresh copy on every call.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from the
        ``mrm8488/t5-base-finetuned-emotion`` checkpoint.
    """
    model_name = "mrm8488/t5-base-finetuned-emotion"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
|
|
|
|
|
# Module-level tokenizer/model pair; analyze_emotion reads both names.
(tokenizer, model) = load_model_and_tokenizer()
|
|
|
|
|
|
|
def analyze_emotion(text):
    """Return a display string describing the emotion predicted for *text*.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The string to classify.

    Returns:
        ``"Please enter some text to analyze."`` when *text* is empty or
        whitespace-only; otherwise ``"Emotion: <Label>"``, where the label
        is the first decoded sequence from ``model.generate``, capitalized.
    """
    # Guard clause: nothing left after stripping whitespace.
    if not text.strip():
        return "Please enter some text to analyze."

    # The end-of-sequence marker is appended to the raw text before encoding.
    encoded = tokenizer.encode(text + '</s>', return_tensors='pt')

    # Generation is capped at a total length of 2.
    generated = model.generate(input_ids=encoded, max_length=2)

    # Only the first generated sequence is reported.
    label = tokenizer.decode(generated[0], skip_special_tokens=True)
    return f"Emotion: {label.capitalize()}"
|
|
|
|
|
|
|
# Classify and display the current input only when the button call
# returns a truthy value.
analyze_clicked = st.button("Analyze Emotion")
if analyze_clicked:
    result = analyze_emotion(user_input)
    st.write(result)
|
|