# Page banner captured together with this file: "Spaces: Runtime error".
import functools

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Checkpoints offered in the model selector. Each entry is handed verbatim to
# `from_pretrained`, so it must be a loadable model identifier.
# NOTE(review): "openai-gpt-2" and "openai-gpt3" are not listed because they
# do not appear to resolve to any published checkpoint (GPT-2 is published as
# "gpt2"; no GPT-3 weights are available) - confirm before re-adding.
MODELS = [
    "gpt2",
    "distilgpt2",
    "openai-gpt",
]

# Default instruction text placed in front of the user's message.
SYSTEM_PROMPT = "You are a helpful assistant. Answer the user's questions as best as you can."

# Conversation history, keyed by model name; each value is a list of messages.
conversation_history = {}
# Building a pipeline loads the full model weights, so keep a few recently
# used pipelines around instead of reloading them on every request.
@functools.lru_cache(maxsize=4)
def _load_pipeline(model_name):
    """Return a text-generation pipeline for *model_name*, memoized per name."""
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def chatbot_response(input_text, model_name, system_prompt):
    """Generate a reply to *input_text* with the selected model.

    Args:
        input_text: The user's message. ``None`` is treated as empty.
        model_name: Model identifier passed to ``from_pretrained``.
        system_prompt: Instruction text placed before the user's message.
            ``None`` falls back to the module-level ``SYSTEM_PROMPT``; an
            empty string means "no system prompt".

    Returns:
        The stripped ``generated_text`` of the first generation, or ``""``
        when there is nothing to generate from.
        NOTE(review): with the pipeline's default settings the generated text
        presumably still starts with the prompt - confirm whether
        ``return_full_text=False`` is wanted here.

    Raises:
        ValueError: If no model name was supplied.
    """
    if not model_name:
        raise ValueError("No model selected.")

    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT

    # The pipeline tokenizes internally, so it must receive the raw prompt
    # string rather than a pre-tokenized tensor batch.
    prompt = " ".join(part for part in (system_prompt, input_text) if part)
    if not prompt:
        return ""

    generations = _load_pipeline(model_name)(prompt)
    return generations[0]["generated_text"].strip()
# Build the Gradio UI: three inputs that map positionally onto
# chatbot_response(input_text, model_name, system_prompt) and one text output.
# Components come from the top-level `gr` namespace because the `value=`
# keyword used for initial values belongs to that API, not to the legacy
# `gr.inputs` classes.
interface = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User input"),
        # Preselect the first model so the callback never receives None.
        gr.Radio(
            choices=MODELS,
            value=MODELS[0] if MODELS else None,
            label="Model",
        ),
        gr.Textbox(label="System prompt", value=SYSTEM_PROMPT),
    ],
    outputs="text",
    title="Large Language Model Chatbot",
    description="Chat with a large language model from the HuggingFace Transformers library.",
)
# Give every offered model its own, initially empty, message list.
conversation_history.update({name: [] for name in MODELS})
def update_history(history, new_message):
    """Append *new_message* to *history* in place and return the same list."""
    history.extend([new_message])
    return history
def display_history(history):
    """Render *history* as one string with one message per line."""
    line_break = "\n"
    return line_break.join(history)
# Hidden text area intended to show the conversation history.
# `gr.Block` is not something Gradio provides, so a Textbox is used instead:
# it accepts the same label / elem_id / visible options and can hold the
# newline-joined string produced by display_history().
history_block = gr.Textbox(
    label="Conversation History",
    elem_id="history",
    visible=False,
)
def update_history_on_message(history, model, new_message):
    """Record *new_message* and store the history under *model*.

    The message is appended to *history* via update_history(), the resulting
    list is saved in ``conversation_history[model]``, and that same list is
    returned.
    """
    conversation_history[model] = update_history(history, new_message)
    return conversation_history[model]
def display_history_on_message(history):
    """Return the display text for *history* (delegates to display_history)."""
    rendered = display_history(history)
    return rendered
# History tracking is intentionally not wired to the UI here.
# NOTE(review): `gr.Interface` does not appear to expose `change` / `submit`
# event-registration methods, and event inputs/outputs have to be Gradio
# components rather than the plain `conversation_history` dict, so attaching
# the history handlers at this point would raise before the app could start.
# TODO: to surface the history, move the UI into a `gr.Blocks` layout and
# connect update_history_on_message / display_history_on_message to
# component events there.

# Launch the Gradio interface
interface.launch()