import os
import random  # NOTE(review): not used in this module
import requests  # NOTE(review): not used in this module

import gradio as gr
from dotenv import load_dotenv
from text_generation import Client  # Client for a text-generation inference endpoint

# Load environment variables (e.g. from a local .env file) before reading them.
load_dotenv()
hf_api_key = os.environ['HF_API_KEY']

# Initialize the client for the configured generation endpoint.
client = Client(
    os.environ['HF_API_FALCOM_BASE'],
    headers={"Authorization": f"Basic {hf_api_key}"},
    timeout=120,
)

# Stop generating when the model starts a new user turn or emits its
# end-of-text token. (An empty-string stop sequence is never valid.)
_CHAT_STOP_SEQUENCES = ["\nUser:", "<|endoftext|>"]


# Text generation function
def generate(input_text, max_tokens):
    """Return the endpoint's completion for a single prompt.

    Args:
        input_text: Prompt text sent to the endpoint.
        max_tokens: Maximum number of new tokens to generate. Coerced to
            ``int`` because a Gradio slider may deliver a float.

    Returns:
        The ``generated_text`` field of the endpoint response.
    """
    return client.generate(input_text, max_new_tokens=int(max_tokens)).generated_text


# Gradio interface for text generation
demo_text_gen = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(label="Max new tokens", value=20, maximum=1024, minimum=1, step=1),
    ],
    outputs=gr.Textbox(label="Generated Text"),
)


# Chat history management
def format_chat_prompt(message, chat_history, instruction=None):
    """Build a plain-text chat prompt from prior turns plus a new message.

    Args:
        message: The new user message.
        chat_history: Iterable of ``(user_msg, bot_msg)`` pairs from earlier
            turns; ``None`` is treated as an empty history.
        instruction: Optional system instruction, emitted first as a
            ``System:`` line. When omitted or empty, the prompt begins
            directly with the history (the original two-argument behavior).

    Returns:
        The prompt string. It ends with ``Assistant:`` so the model
        continues as the assistant.
    """
    parts = []
    if instruction:
        parts.append(f"System: {instruction}")
    for user_msg, bot_msg in chat_history or []:
        parts.append(f"\nUser: {user_msg}\nAssistant: {bot_msg}")
    parts.append(f"\nUser: {message}\nAssistant:")
    return "".join(parts)


def _strip_stop_sequences(text, stop_sequences):
    """Drop one trailing stop sequence from *text*, then trim whitespace."""
    # NOTE(review): the server may echo the matched stop sequence at the end
    # of generated_text; strip it so it does not show up in the chat window.
    for stop in stop_sequences:
        if stop and text.endswith(stop):
            text = text[: -len(stop)]
            break
    return text.strip()


# Chatbot response generation
def respond(message, chat_history, instruction, temperature=0.7):
    """Generate the assistant's reply and record the turn in the history.

    Args:
        message: The new user message.
        chat_history: List of ``(user_msg, bot_msg)`` pairs. It is appended
            to in place; ``None`` starts a fresh list.
        instruction: System instruction placed at the top of the prompt.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        ``(reply_text, chat_history)`` with the new turn appended.
    """
    if chat_history is None:
        chat_history = []
    prompt = format_chat_prompt(message, chat_history, instruction)
    response = client.generate(
        prompt,
        max_new_tokens=1024,
        stop_sequences=list(_CHAT_STOP_SEQUENCES),
        temperature=temperature,
    )
    reply = _strip_stop_sequences(response.generated_text, _CHAT_STOP_SEQUENCES)
    chat_history.append((message, reply))
    return reply, chat_history


def _chat_submit(message, chat_history, instruction, temperature):
    """Gradio callback: run respond(), clear the input box, refresh the chat."""
    _, chat_history = respond(message, chat_history, instruction, temperature)
    return "", chat_history


# Gradio interface for chatbot
with gr.Blocks() as demo_chatbot:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    system_msg = gr.Textbox(label="System Instruction", value="A conversation with an AI.")
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1, value=0.7)
    submit_btn = gr.Button("Send")
    # Event inputs must be components: the Chatbot component itself carries
    # the conversation history between calls.
    chat_inputs = [msg, chatbot, system_msg, temperature_slider]
    chat_outputs = [msg, chatbot]
    submit_btn.click(_chat_submit, inputs=chat_inputs, outputs=chat_outputs)
    msg.submit(_chat_submit, inputs=chat_inputs, outputs=chat_outputs)

# Launch Gradio apps
if __name__ == "__main__":
    gr.close_all()
    # launch() blocks the main thread in a plain script by default. Release it
    # for the first app so the second app actually starts; the last launch
    # keeps blocking to hold the process open.
    demo_text_gen.launch(
        server_port=int(os.environ.get('PORT1', 7860)),
        prevent_thread_lock=True,
    )
    demo_chatbot.launch(server_port=int(os.environ.get('PORT2', 7861)))