import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig #, TextIteratorStreamer
import torch
import time
import requests
import json

repo_name = "BeardedMonster/SabiYarn-125M"
tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
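# Note: Streamlit re-runs this script on every widget interaction, so the
# tokenizer above is reloaded each time. A cached loader (a sketch, assuming a
# recent Streamlit with st.cache_resource) would avoid the repeated load:
#
#   @st.cache_resource
#   def load_tokenizer(name):
#       return AutoTokenizer.from_pretrained(name, trust_remote_code=True)
#   tokenizer = load_tokenizer(repo_name)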
# 1. Enter your text in the text area.
# 2. Click on 'Generate' to get the generated text.

# Add sidebar with instructions
st.sidebar.title("Instructions: How to use")
st.sidebar.write("""
1. These are the tags for languages: **<yor>, <ibo>, <hau>, <efi>, <urh>, <fuv>, <eng>, <pcm>**\n
2. You can either input any string in the specified language:\n
    - how are you?, Abin mamaki ne aikin da\n
   Or you can prompt the model with questions or instructions using defined tags:\n
    - **Instruction following:**
        <prompt> who are you? <response>:, <prompt> speak yoruba <response>:\n
    - **Sentiment classification:**
        <classify> I don't like apples <sentiment>\n
    - **Topic classification:**
        <classify> TEXT HERE <topic>\n
    - **Title generation:**
        <title> TEXT HERE <headline>\n
    - **Diacritize text:**
        <diacritize> TEXT HERE <yor>, <diacritize> TEXT HERE <hau>\n
    - **Clean text:**
        <clean> TEXT HERE <hau>\n
**Note: The model's performance varies with prompts due to its size and the training data distribution.**\n
3. Lastly, you can play with some of the generation parameters below to improve performance.
""")
# Define generation configuration
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=500, value=100)
num_beams = st.sidebar.slider("Number of Beams", min_value=1, max_value=10, value=5)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=2.0, value=0.9)
top_k = st.sidebar.slider("Top-K", min_value=1, max_value=100, value=50)
top_p = st.sidebar.slider("Top-P", min_value=0.1, max_value=1.0, value=0.95)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=10.0, value=2.0)
length_penalty = st.sidebar.slider("Length Penalty", min_value=0.1, max_value=10.0, value=1.7)
# early_stopping = st.sidebar.selectbox("Early Stopping", [True, False], index=0)
generation_config = {
    "max_length": max_length,
    "num_beams": num_beams,
    "do_sample": True,
    "temperature": temperature,
    "top_k": top_k,
    "top_p": top_p,
    "repetition_penalty": repetition_penalty,
    "length_penalty": length_penalty,
    "early_stopping": True
}
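# The keys above mirror the fields of transformers' GenerationConfig (imported
# but unused here, since generation is delegated to the API below). If local
# generation were re-enabled, a sketch of driving it from the same dict:
#
#   config = GenerationConfig(**generation_config)
#   model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
#   output_ids = model.generate(input_ids, generation_config=config)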
# Streamlit app
st.title("SabiYarn-125M: Generates text in multiple Nigerian languages.")
st.write("**Supported languages: English, Yoruba, Igbo, Hausa, Pidgin, Efik, Urhobo, Fulfulde, Fulah. \nResults might not be coherent for the less represented languages (i.e., Efik, Urhobo, Fulfulde, Fulah).**")
st.write("**The model runs on CPU RAM, so token generation might be slower (streaming not enabled).**")
st.write("**Avg response time: 15 secs / 50 tokens. Response time increases with input length.**")
st.write("-" * 50)
def generate_from_api(user_input, generation_config):
    # Note: the URL must not contain leading whitespace, or the request fails.
    url = "https://pauljeffrey--sabiyarn-fastapi-app.modal.run/predict"
    payload = json.dumps({
        "prompt": user_input,
        "config": generation_config
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.post(url, headers=headers, data=payload)
    return response.text
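# A more defensive variant (a sketch, not in the original) would bound the
# request time and surface HTTP errors instead of returning an error page body:
#
#   response = requests.post(url, headers=headers, data=payload, timeout=120)
#   response.raise_for_status()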
# Text input
user_input = st.text_area("Enter text below **(please, first read the instructions on how to use in the sidebar)**: ", "<prompt> Tell me a story in pidgin <response>:")

if st.button("Generate"):
    if user_input:
        try:
            st.write("**Generated Text Below:**")
            # input_ids = tokenizer(user_input, return_tensors="pt")["input_ids"].to(device)
            full_output = st.empty()
            start_time = time.time()
            # generated_text = generate_and_stream_text(input_ids, generation_config)
            generated_text = generate_from_api(user_input, generation_config)
            end_time = time.time()

            # Simulate streaming by revealing the generated text one token at a
            # time (the API itself returns the full text in one response).
            tokens = tokenizer.tokenize(generated_text)
            for i in range(1, len(tokens) + 1):
                full_output.text(tokenizer.convert_tokens_to_string(tokens[:i]))
                time.sleep(0.05)

            time_diff = end_time - start_time
            st.write("Time taken:", time_diff, "seconds.")
        except Exception as e:
            st.error(f"Error during text generation: {e}")
    else:
        st.write("Please enter some text to generate.")