# Page-scrape header (not part of the program): raghunc0 · model · 60616b8 · raw · history blame · 1.52 kB
import gradio as gr
from pathlib import Path
import torch
from tsai_gpt.generate_for_app import generate_for_app
# Path to the model weights file. Nothing in the visible part of this file
# reads this constant: generate_text passes the checkpoint *directory* to
# generate_for_app instead.
# NOTE(review): the name says "pythia" but the path points at a Llama-2
# checkpoint directory -- confirm which model is actually stored there.
pythia_model = "checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"
def generate_text(prompt):
    """Generate text for ``prompt`` using the bundled checkpoint.

    Forwards the prompt to ``generate_for_app`` with fixed settings: one
    sample, at most 200 new tokens, temperature 0.9, and the
    ``checkpoints/meta-llama/Llama-2-7b-chat-hf/`` checkpoint directory.

    Args:
        prompt: The prompt value supplied by the caller (the Gradio textbox
            in this app).

    Returns:
        Whatever ``generate_for_app`` returns for these settings.
    """
    checkpoint_dir = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/")
    return generate_for_app(
        prompt,
        num_samples=1,
        max_new_tokens=200,
        temperature=0.9,
        checkpoint_dir=checkpoint_dir,
    )
#gr.Interface(fn=generate_nanogpt_text, inputs=gr.Button(value="Generate text!"), outputs='text').launch(share=True)
# Introductory Markdown rendered at the top of the demo page. Kept flush-left
# in a module-level constant so the Markdown source carries no stray
# indentation. The token limit quoted here is 200 to match the
# max_new_tokens=200 that generate_text passes to the generator.
_DESCRIPTION_MD = """
# Example of text generation with our pythia 160M model based on the RedPajama sample data:
The model checkpoint is the 'checkpoints/meta-llama/Llama-2-7b-chat-hf' dir. The hyper params used are the exact same emitted by the training main.ipynb notebook. The loss is less than 3.5; we can see syntactically correct but semantically meaningless sentences.
Keep in mind the output is limited to 200 tokens so the inference runs within reasonable time (10s) on CPU. (Huggingface free tier)
GPU inference can output much much longer sequences.
Click on the "Generate text" button to see the generated text.
"""

# Build the demo page: description, a trigger button, a prompt textbox and an
# output textbox. Components are created in the same order as before, since
# creation order inside a Blocks context determines the on-page layout.
with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION_MD)
    generate_button = gr.Button("Generate text!")
    # Named prompt_box rather than `input` so the builtin input() is not
    # shadowed at module level.
    prompt_box = gr.Textbox(label="Enter your prompt here")
    output = gr.Textbox(label="Text generated by Pythia 160M trained model")
    # Clicking the button sends the prompt textbox value to generate_text and
    # writes the returned text into the output textbox.
    generate_button.click(
        fn=generate_text,
        inputs=prompt_box,
        outputs=output,
        api_name='text generation sample',
    )

# Start the Gradio server when this script is executed.
demo.launch()