# SabiYarn-125M / app.py
# Last commit: e9dc958 "update" by BeardedMonster (verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig #, TextIteratorStreamer
import torch
import time
import requests
import json
repo_name = "BeardedMonster/SabiYarn-125M"


@st.cache_resource
def _load_tokenizer(name):
    """Load the tokenizer for *name* once per server process.

    Streamlit re-executes this entire script on every widget interaction,
    so without caching the tokenizer would be rebuilt on each rerun.
    """
    # trust_remote_code allows the Hub repository to supply its own
    # tokenizer code; that code executes locally.
    return AutoTokenizer.from_pretrained(name, trust_remote_code=True)


tokenizer = _load_tokenizer(repo_name)
# Usage in short:
# 1. Enter your text in the text area.
# 2. Click on 'Generate' to get the generated text.
# The sidebar below documents the language tags and task prompts the model
# understands.
st.sidebar.title("Instructions: How to use")
st.sidebar.write("""
1. These are the tags for languages: **<yor>, <ibo>, <hau>, <efi>, <urh>, <fuv>, <eng>, <pcm>**\n
2. You can either input any string in the specified language:\n
- how are you?, Abin mamaki ne aikin da\n
Or you can prompt the model with questions or instructions using defined tags:\n
- **Instruction following:**
<prompt> who are you? <response>:, <prompt> speak yoruba <response>:\n
- **Sentiment classification:**
<classify> I don't like apples <sentiment>\n
- **Topic classification:**
<classify> TEXT HERE <topic>\n
- **Title Generation:**
<title> TEXT HERE <headline>\n
- **Diacritize text:**
<diacritize> TEXT HERE <yor>, <diacritize> TEXT HERE <hau>\n
- **Clean text:**
<clean> TEXT HERE <hau>\n
**Note: Model performance varies with prompts due to model size and training data distribution.**\n
3. Lastly, you can play with some of the generation parameters below to improve performance.
""")
# Decoding parameters the user can tune from the sidebar. The sliders are
# created in this order, so this is also their on-screen order.
max_length = st.sidebar.slider("Max Length", min_value=10, max_value=500, value=100)
num_beams = st.sidebar.slider("Number of Beams", min_value=1, max_value=10, value=5)
temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=2.0, value=0.9)
top_k = st.sidebar.slider("Top-K", min_value=1, max_value=100, value=50)
top_p = st.sidebar.slider("Top-P", min_value=0.1, max_value=1.0, value=0.95)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=10.0, value=2.0)
length_penalty = st.sidebar.slider("Length Penalty", min_value=0.1, max_value=10.0, value=1.7)

# Payload forwarded as "config" to the inference endpoint. Sampling is always
# on and early stopping is fixed to True (there is no sidebar control for it).
generation_config = dict(
    max_length=max_length,
    num_beams=num_beams,
    do_sample=True,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    length_penalty=length_penalty,
    early_stopping=True,
)
# Streamlit app: page title and usage notices.
st.title("SabiYarn-125M: Generates text in multiple Nigerian languages.")
# Implicit string concatenation (instead of a backslash continuation inside
# the literal) keeps the space between "Efik," and "Urhobo".
st.write(
    "**Supported Languages: English, Yoruba, Igbo, Hausa, Pidgin, Efik, Urhobo, Fulfulde, Fulah. "
    "\nResults might not be coherent for less represented languages (i.e. Efik, "
    "Urhobo, Fulfulde, Fulah).**"
)
st.write("**Model is running on CPU RAM. So, token generation might be slower (streaming not enabled).**")
st.write("**Avg Response time: 15 secs / 50 tokens. Response time increases with input length.**")
st.write("-" * 50)
API_URL = "https://pauljeffrey--sabiyarn-fastapi-app.modal.run/predict"


def generate_from_api(user_input, generation_config, url=API_URL, timeout=300):
    """POST a prompt to the SabiYarn inference endpoint and return its reply.

    Args:
        user_input: Prompt text, possibly containing task tags such as
            ``<prompt> ... <response>:``.
        generation_config: Decoding parameters, sent verbatim under the
            ``"config"`` key of the JSON body.
        url: Endpoint to POST to; defaults to the hosted Modal app.
        timeout: Seconds to wait on connect/read. Without a timeout,
            ``requests`` can block the Streamlit script indefinitely.

    Returns:
        The raw response body as text.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status,
            so an error page is never shown as if it were generated text.
        requests.RequestException: on connection failures or timeouts.
    """
    payload = json.dumps({"prompt": user_input, "config": generation_config})
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    response.raise_for_status()
    return response.text
# Text input
user_input = st.text_area(
    "Enter text below **(please, first read the instructions on how to use in the side bar)**: ",
    "<prompt> Tell me a story in pidgin <response>:",
)

# Pause between simulated-streaming updates. The API returns the whole
# completion at once, so this only paces the on-screen reveal; it must stay
# small (2 s per token would add minutes of dead time to long outputs).
STREAM_DELAY_SECONDS = 0.02

if st.button("Generate"):
    # Reject empty and whitespace-only prompts.
    if user_input and user_input.strip():
        try:
            st.write("**Generated Text Below:**")
            full_output = st.empty()

            # Time only the remote generation call, not the reveal animation.
            start_time = time.time()
            generated_text = generate_from_api(user_input, generation_config)
            end_time = time.time()

            # Re-tokenize the finished text and reveal a growing prefix to
            # imitate streaming. Decoding a prefix of ids may briefly show a
            # partial multi-byte character, so the exact text is written last.
            token_ids = tokenizer.encode(generated_text, add_special_tokens=False)
            for prefix_len in range(1, len(token_ids) + 1):
                full_output.text(tokenizer.decode(token_ids[:prefix_len]))
                time.sleep(STREAM_DELAY_SECONDS)
            full_output.text(generated_text)

            time_diff = end_time - start_time
            st.write("Time taken: ", time_diff, "seconds.")
        except Exception as e:
            # Top-level UI boundary: report network, HTTP-status or tokenizer
            # failures to the user instead of crashing the script.
            st.error(f"Error during text generation: {e}")
    else:
        st.write("Please enter some text to generate.")