# Spaces:
# Runtime error
# Runtime error
import re

import streamlit as st
from transformers import pipeline
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
def summarize(data, modelname):
    """Summarize ``data`` with the model selected by ``modelname``.

    Args:
        data: The text to summarize.
        modelname: ``'Bart'`` (the ``facebook/bart-large-cnn`` summarization
            pipeline) or ``'Pegasus'`` (``google/pegasus-xsum``).

    Returns:
        The generated summary as a plain string.

    Raises:
        ValueError: If ``modelname`` is not one of the supported models.
    """
    # NOTE(review): the model is loaded on every call; caching it (e.g. with
    # Streamlit's resource cache) would make repeated clicks much faster.
    if modelname == 'Bart':
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        # truncation=True cuts inputs longer than the model's maximum input
        # length instead of letting the forward pass fail on them.
        output = summarizer(
            data, max_length=130, min_length=30, do_sample=False, truncation=True
        )
        return output[0]["summary_text"]
    if modelname == 'Pegasus':
        model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
        tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
        # Create tokens - number representation of our text
        tokens = tokenizer(data, truncation=True, padding="longest", return_tensors="pt")
        summary = model.generate(**tokens)
        # The generated ids include special markers (e.g. <pad>, </s>);
        # skip them so only the summary text is returned.
        return tokenizer.decode(summary[0], skip_special_tokens=True)
    # Previously an unknown name fell through and returned None, which made
    # the caller crash later on len(output); fail loudly here instead.
    raise ValueError(
        f"Unknown model name: {modelname!r} (expected 'Bart' or 'Pegasus')"
    )
st.sidebar.title("Text Summarization")
uploaded_file = st.sidebar.file_uploader("Choose a file")

# Script-level state read further down: ``data`` is the text to summarize,
# ``output`` the generated summary ("" until one has been produced).
data = ""
output = ""
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()
    try:
        data = bytes_data.decode("utf-8")
    except UnicodeDecodeError:
        # A non-UTF-8 upload used to crash the whole app with a traceback;
        # report it and leave ``data`` empty instead.
        st.sidebar.error("The uploaded file is not valid UTF-8 text.")

modelname = st.sidebar.radio(
    "Choose your model",
    ["Bart", "Pegasus"],
    help=" you can choose between 2 models (Bart or Pegasus) to summarize your text. More to come!",
)
# Maximum number of words sent to the model; also used in user-facing text.
MAX_WORDS = 500

col1, col2 = st.columns(2)
with col1:
    st.header("Copy paste your text or Upload file")
    if uploaded_file is not None:
        # Show the uploaded file's contents.
        with st.expander("Text to summarize", expanded=True):
            st.write(data)
    else:
        with st.expander("Text to summarize", expanded=True):
            data = st.text_area(f"Paste your text below (max {MAX_WORDS} words)", height=510, )

    # A "word" is a run of \w characters. Keep the match objects so the text
    # can be cut at a word boundary rather than at a character count.
    word_matches = list(re.finditer(r"\w+", data))
    res = len(word_matches)
    if res > MAX_WORDS:
        st.warning(
            "⚠️ Your text contains "
            + str(res)
            + " words."
            + f" Only the first {MAX_WORDS} words will be reviewed."
            + " Stay tuned as increased allowance is coming! 😊"
        )
        # Keep everything up to the end of the MAX_WORDS-th word. The previous
        # data[:MAX_WORDS] kept only 500 *characters*, not 500 words.
        data = data[: word_matches[MAX_WORDS - 1].end()]

    Summarizebtn = st.button("Summarize")
    if Summarizebtn:
        if data.strip():
            output = summarize(data, modelname)
        else:
            # Avoid running the model on empty input.
            st.warning("Please upload a file or paste some text to summarize.")
with col2:
    st.header("Summary")
    # ``output`` is "" until a summary is produced on this run. A truthiness
    # check also tolerates None, where len(output) would raise TypeError.
    if output:
        with st.expander("", expanded=True):
            st.write(output)