import gradio as gr
from llm_rs import AutoModel,SessionConfig,GenerationConfig,Precision
# Sample instructions offered in the UI's examples panel.
examples = [
    "Write a travel blog about a 3-day trip to Thailand.",
    "Tell me a short story about a robot that has a nice day.",
    "Compose a tweet to congratulate rustformers on the launch of their HuggingFace Space.",
    "Explain how a candle works to a 6-year-old in a few sentences.",
    "What are some of the most common misconceptions about birds?",
    "Explain why the Rust programming language is so popular.",
]

# Repository identifier and weight file handed to AutoModel.from_pretrained.
# NOTE(review): this selects the "storywriter" weights, while the prompt
# template and UI text in this file refer to MPT-7B-Instruct — confirm which
# variant is intended.
repo_name = "rustformers/mpt-7b-ggml"
file_name = "mpt-7b-storywriter-q5_1.bin"

# Session settings used for inference: two threads, batch size of two.
session_config = SessionConfig(threads=2, batch_size=2)

# The model is created once at module level; process_stream reuses this
# single instance for every request.
model = AutoModel.from_pretrained(
    repo_name,
    model_file=file_name,
    session_config=session_config,
    verbose=True,
)
def process_stream(instruction, temperature, top_p, top_k, max_new_tokens, seed):
    """Stream the model's response to ``instruction`` as a growing string.

    The instruction is embedded in the instruction/response prompt template
    and handed to ``model.stream``. After every chunk the streamer produces,
    the text accumulated so far is yielded, so a consumer always receives the
    complete response up to that point rather than just the newest piece.

    Args:
        instruction: The user's question or instruction text.
        temperature: Forwarded unchanged to ``GenerationConfig``.
        top_p: Forwarded unchanged to ``GenerationConfig``.
        top_k: Forwarded unchanged to ``GenerationConfig``.
        max_new_tokens: Forwarded unchanged to ``GenerationConfig``.
        seed: Forwarded unchanged to ``GenerationConfig``.

    Yields:
        The accumulated response text after each streamed chunk. When the
        instruction is missing or contains only whitespace, a single empty
        string is yielded and the model is not invoked.
    """
    # Nothing to answer: clear the output instead of spending a generation
    # run on a prompt that carries no instruction.
    if not instruction or not instruction.strip():
        yield ""
        return

    # The template lines intentionally start at column 0: any indentation
    # inside the triple-quoted string would become part of the prompt.
    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
Answer:"""

    generation_config = GenerationConfig(
        seed=seed,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )

    response = ""
    streamer = model.stream(prompt=prompt, generation_config=generation_config)
    for new_text in streamer:
        response += new_text
        # Yield the whole text so far; the consumer replaces its content
        # with each yielded value.
        yield response
# Build the demo UI: a heading, an instruction textbox with collapsible
# sampling options and a submit button, a Markdown output area, clickable
# examples, and the disclaimer / privacy footers.
# NOTE(review): the heading and the output label name MPT-7B-Instruct;
# confirm that this matches the weights selected by `file_name`.
with gr.Blocks(
    theme=gr.themes.Soft(),
    # Elements tagged with the "disclaimer" class are rendered in small caps.
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        "\n"
        "MPT-7B-Instruct on CPU in Rust 🦀\n"
        "This demo uses the [rustformers/llm](https://github.com/rustformers/llm) library via [llm-rs](https://github.com/LLukas22/llm-rs-python) to execute [MPT-7B-Instruct](https://huggingface.co/mosaicml/mpt-7b-instruct) on 2 CPU cores.\n"
    )

    with gr.Row():
        with gr.Column():
            with gr.Row():
                instruction = gr.Textbox(
                    placeholder="Enter your question or instruction here",
                    label="Question/Instruction",
                    elem_id="q-input",
                )

            # Sampling parameters, collapsed by default.
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.8,
                                minimum=0.1,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.95,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=5,
                                maximum=80,
                                step=1,
                                interactive=True,
                                # The slider only offers 5..80, so the help
                                # text must not advertise a "0 to disable"
                                # setting that cannot be selected.
                                info="Sample from a shortlist of top-k tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            max_new_tokens = gr.Slider(
                                label="Maximum new tokens",
                                value=256,
                                minimum=0,
                                maximum=1024,
                                step=5,
                                interactive=True,
                                info="The maximum number of new tokens to generate",
                            )
                    with gr.Column():
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed",
                                value=42,
                                interactive=True,
                                info="The seed to use for the generation",
                                precision=0,
                            )

            with gr.Row():
                submit = gr.Button("Submit")

    # Output area: process_stream's yielded text is written into this
    # Markdown component.
    with gr.Row():
        with gr.Box():
            gr.Markdown("**MPT-7B-Instruct**")
            output_7b = gr.Markdown()

    # The examples are bound to the instruction textbox only; example
    # caching is disabled.
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[instruction],
            cache_examples=False,
            fn=process_stream,
            outputs=output_7b,
        )

    with gr.Row():
        gr.Markdown(
            "Disclaimer: MPT-7B can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. MPT-7B was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    with gr.Row():
        gr.Markdown(
            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
            elem_classes=["disclaimer"],
        )

    # Clicking Submit and submitting the textbox both start a generation
    # with the same inputs and the same output, so they are wired from one
    # shared definition.
    generation_inputs = [instruction, temperature, top_p, top_k, max_new_tokens, seed]
    for register_listener in (submit.click, instruction.submit):
        register_listener(
            process_stream,
            inputs=generation_inputs,
            outputs=output_7b,
        )
# Configure the request queue (max_size=4, concurrency_count=1), then start
# the app with debug mode on. queue() is applied to `demo` itself, so the
# launch can be issued as a separate statement on the same object.
demo.queue(max_size=4, concurrency_count=1)
demo.launch(debug=True)