"""Gradio playground for the Apollo GGUF medical chat models, served with llama-cpp-python."""

import os

import gradio as gr
from huggingface_hub.file_download import http_get
from llama_cpp import Llama

SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."

# Directory where the GGUF model files are cached. (Previously named `dir`,
# which shadowed the builtin.)
MODEL_DIR = "."


def get_message_tokens(model, role, content):
    """Tokenize one chat turn: the role, a newline, the content, and a newline.

    Args:
        model: Loaded ``Llama`` instance used for tokenization.
        role: Speaker label written before the content ("system", "user" or "bot" in this file).
        content: Text of the turn.

    Returns:
        The token ids produced by ``model.tokenize`` (special-token parsing enabled).
    """
    # NOTE(review): the turn carries no role delimiters / special tokens even
    # though tokenize() is called with special=True; markers such as
    # angle-bracketed tokens may have been lost from this template. Confirm
    # against the Apollo chat format.
    text = f"{role}\n{content}\n"
    return model.tokenize(text.encode("utf-8"), special=True)


def get_system_tokens(model):
    """Tokenize the system turn built from SYSTEM_PROMPT."""
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)


def load_model(directory, model_name, model_url, *, n_ctx=1024):
    """Download ``model_name`` into ``directory`` if it is missing, then load it.

    The file is first downloaded to ``<model_name>.part`` and renamed to its final
    name only after the download completes. An interrupted or failed download
    therefore never leaves a truncated file that a later run would mistake for a
    cached model.

    Args:
        directory: Folder that holds the cached GGUF files (created if missing).
        model_name: File name of the GGUF model inside ``directory``.
        model_url: URL fetched with ``huggingface_hub``'s ``http_get``.
        n_ctx: Context window size passed to ``Llama`` (default 1024, as before).

    Returns:
        The loaded ``Llama`` instance.

    Raises:
        Whatever ``http_get`` or ``Llama`` raise; a partial download is removed first.
    """
    final_model_path = os.path.join(directory, model_name)
    print(f"Checking model: {model_name}")
    if not os.path.exists(final_model_path):
        print(f"Downloading model: {model_name}")
        os.makedirs(directory, exist_ok=True)
        partial_path = final_model_path + ".part"
        try:
            with open(partial_path, "wb") as f:
                http_get(model_url, f)
        except BaseException:
            # Drop the truncated download so the next attempt starts clean.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, final_model_path)
        # NOTE(review): 0o777 makes the model file world-writable; kept from the
        # original, but 0o644 is likely sufficient.
        os.chmod(final_model_path, 0o777)
    print(f"Model {model_name} ready!")
    model = Llama(model_path=final_model_path, n_ctx=n_ctx)
    print(f"Model {model_name} loaded successfully!")
    return model


# Registry of selectable models, keyed by "<model type> <size>" (see get_model_key).
MODEL_OPTIONS = {
    "Apollo 0.5B": {
        "directory": MODEL_DIR,
        "model_name": "apollo-0.5b.gguf",
        # NOTE(review): placeholder URL; this will not download a GGUF file.
        "model_url": "https://huggingface.co/path_to_apollo_0.5b_model",
    },
    "Apollo 2B": {
        "directory": MODEL_DIR,
        "model_name": "apollo-2b.gguf",
        # NOTE(review): placeholder URL; this will not download a GGUF file.
        "model_url": "https://huggingface.co/path_to_apollo_2b_model",
    },
    "Apollo 7B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-7B-q8_0.gguf",
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF/resolve/main/Apollo-7B-q8_0.gguf",
    },
    "Apollo2 0.5B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-0.5B-q8_0.gguf",
        # NOTE(review): this points at the Apollo (v1) 0.5B repo; confirm intended.
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-0.5B-GGUF/resolve/main/Apollo-0.5B-q8_0.gguf",
    },
    "Apollo2 2B": {
        "directory": MODEL_DIR,
        "model_name": "Apollo-2B-q8_0.gguf",
        # NOTE(review): this points at the Apollo (v1) 2B repo; confirm intended.
        "model_url": "https://huggingface.co/FreedomIntelligence/Apollo-2B-GGUF/resolve/main/Apollo-2B-q8_0.gguf",
    },
    "Apollo2 7B": {
        "directory": MODEL_DIR,
        "model_name": "apollo2-7b-q8_0.gguf",
        "model_url": "https://huggingface.co/nchen909/Apollo2-7B-Q8_0-GGUF/resolve/main/apollo2-7b-q8_0.gguf",
    },
}

# Currently loaded model and its MODEL_OPTIONS key (both None until the first load).
MODEL = None
CURRENT_MODEL_KEY = None


def get_model_key(model_type, model_size):
    """Return the MODEL_OPTIONS key for a dropdown selection, e.g. "Apollo2 7B"."""
    return f"{model_type} {model_size}"


def initialize_model(model_type, model_size):
    """Make the selected model the active MODEL, loading it only if it is not already active.

    On failure the error is printed and both MODEL and CURRENT_MODEL_KEY are left
    as None, so callers can detect the failure by checking ``MODEL is None``.
    """
    global MODEL, CURRENT_MODEL_KEY
    model_key = get_model_key(model_type, model_size)

    if CURRENT_MODEL_KEY == model_key and MODEL is not None:
        print(f"Model {model_key} is already loaded.")
        return

    print(f"Initializing model: {model_key}")
    # Release the previous model's global reference before loading the next one,
    # so two multi-GB models are not kept alive by this module at the same time.
    MODEL = None
    CURRENT_MODEL_KEY = None
    try:
        selected_model = MODEL_OPTIONS[model_key]
        MODEL = load_model(
            directory=selected_model["directory"],
            model_name=selected_model["model_name"],
            model_url=selected_model["model_url"],
        )
    except Exception as e:  # UI boundary: report the failure, keep the app running
        print(f"Failed to initialize model {model_key}: {e}")
        MODEL = None
        return
    CURRENT_MODEL_KEY = model_key
    print(f"Model initialized: {model_key}")


# Functions for chat interactions
def user(message, history, model_type, model_size):
    """Load the selected model if needed and append ``message`` as a pending turn.

    Returns:
        ``("", new_history)``: clears the textbox and shows the new turn with no reply yet.

    Raises:
        gr.Error: if the selected model could not be loaded. Raising here keeps the
            typed message in the textbox and prevents the chained ``.success(bot)`` step.
    """
    initialize_model(model_type, model_size)
    if MODEL is None:
        raise gr.Error(
            f"Could not load model {get_model_key(model_type, model_size)}. "
            "See the server logs for details."
        )
    return "", (history or []) + [[message, None]]


def bot(history, top_p, top_k, temp):
    """Stream the model's reply to the last user turn in ``history``.

    The full prompt is rebuilt on every call: the system turn, every earlier
    user/bot turn, the new user turn, then a "bot" header to prompt the reply.
    After each generated token, ``history`` is yielded with the partial reply
    stored in ``history[-1][1]``.

    Raises:
        RuntimeError: if no model is loaded.
        gr.Error: if the prompt already fills the model's context window.
    """
    model = MODEL
    if model is None:
        raise RuntimeError("Model has not been initialized. Please select a model to load.")

    tokens = get_system_tokens(model)
    for user_message, bot_message in history[:-1]:
        tokens.extend(get_message_tokens(model=model, role="user", content=user_message))
        if bot_message:
            tokens.extend(get_message_tokens(model=model, role="bot", content=bot_message))
    last_user_message = history[-1][0]
    tokens.extend(get_message_tokens(model=model, role="user", content=last_user_message))
    tokens.extend(model.tokenize("bot\n".encode("utf-8"), special=True))

    # Each generated token is fed back into the context, so the reply can use at
    # most the room left after the prompt. Stop there instead of overflowing n_ctx.
    max_new_tokens = model.n_ctx() - len(tokens)
    if max_new_tokens <= 0:
        raise gr.Error(
            "The conversation is too long for the model's context window. Please clear the chat."
        )

    # Slider values may arrive as floats; llama.cpp's top-k is an integer.
    generator = model.generate(tokens, top_k=int(top_k), top_p=top_p, temp=temp)
    eos_token = model.token_eos()

    # Decode the accumulated bytes rather than each token on its own. A
    # multi-byte UTF-8 character can be split across tokens, and decoding the
    # pieces separately with errors="ignore" would drop it. Decoding the whole
    # buffer only hides an incomplete trailing sequence until it is completed.
    reply_bytes = bytearray()
    for n_generated, token in enumerate(generator, start=1):
        if token == eos_token:
            break
        reply_bytes += model.detokenize([token])
        history[-1][1] = reply_bytes.decode("utf-8", "ignore")
        yield history
        if n_generated >= max_new_tokens:
            break


def clear_chat():
    """Clear the chat history."""
    return []


def stop_generation():
    """Log the stop request. The generation itself is stopped via ``cancels`` on the Stop button."""
    print("Generation stopped.")
    return None


CUSTOM_CSS = """
footer {display: none !important;}
#send-message-box {width: 100%;}
#send-btn, #stop-btn, #clear-btn {
    display: inline-block;
    width: 30%;
    margin-right: 2px;
    text-align: center;
}
.gr-row {
    display: flex !important;
    flex-direction: row !important;
    justify-content: space-between;
    align-items: center;
    flex-wrap: nowrap;
}
"""

# Gradio UI
with gr.Blocks(theme=gr.themes.Monochrome(), analytics_enabled=False, css=CUSTOM_CSS) as demo:
    favicon = ''
    gr.Markdown(
        f"""# {favicon} Apollo GGUF Playground

This is a demo of multilingual medical model series **[Apollo](https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF)**, GGUF version. [Apollo1](https://arxiv.org/abs/2403.03640) covers 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) covers 50 languages.
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation")
            msg = gr.Textbox(
                label="Send Message",
                placeholder="Send Message",
                show_label=False,
                elem_id="send-message-box",
            )
        with gr.Column(scale=1):
            with gr.Row(equal_height=False):
                model_type = gr.Dropdown(
                    choices=["Apollo", "Apollo2"],
                    value="Apollo2",
                    label="Select Model",
                    interactive=True,
                    elem_id="model-type-dropdown",
                )
                model_size = gr.Dropdown(
                    choices=["0.5B", "2B", "7B"],
                    value="7B",
                    label="Select Size",
                    interactive=True,
                    elem_id="model-size-dropdown",
                )
            top_p = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top-p",
            )
            top_k = gr.Slider(
                minimum=10, maximum=100, value=30, step=5, interactive=True, label="Top-k",
            )
            temp = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.01, step=0.01, interactive=True, label="Temperature",
            )
    with gr.Row(equal_height=False):
        submit = gr.Button("Send", elem_id="send-btn")
        stop = gr.Button("Stop", elem_id="stop-btn")
        clear = gr.Button("Clear", elem_id="clear-btn")

    # Event bindings: Enter in the textbox and the Send button run the same
    # user -> bot chain. The bot steps are collected so Stop can cancel them.
    generation_events = []
    for trigger in (msg.submit, submit.click):
        generation_events.append(
            trigger(
                fn=user,
                inputs=[msg, chatbot, model_type, model_size],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=bot,
                inputs=[chatbot, top_p, top_k, temp],
                outputs=chatbot,
                queue=True,
            )
        )
    stop.click(
        fn=stop_generation,
        inputs=None,
        outputs=None,
        cancels=generation_events,
        queue=False,
    )
    clear.click(fn=clear_chat, inputs=None, outputs=chatbot, queue=False)

demo.queue(max_size=128)

# The selected model is loaded lazily by `user` on the first message. To preload
# the default model at startup instead, call initialize_model("Apollo2", "7B") here.

if __name__ == "__main__":
    demo.launch(show_error=True, share=True)