File size: 5,009 Bytes
3137439
09f8eba
13a7a5d
3137439
 
 
 
 
 
 
 
09f8eba
 
 
 
 
3137439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6592a3f
3137439
 
6592a3f
3137439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM
import os
import json
import time
import logging
from threading import Lock

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Module-level constants used by EnhancedChatbot below. These were referenced
# throughout the file but never defined, which raised NameError as soon as the
# class was instantiated.
CONFIG_FILE = "chatbot_config.json"   # persisted settings (JSON)
MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"   # default model id
CACHE_DIR = "model_cache"             # where downloaded weights are cached

# NOTE(review): the original code also built a Mixtral-8x7B text-generation
# pipeline here and invoked it with a throwaway "Who are you?" message,
# discarding the result — a pasted model-card snippet that loaded a massive
# model at import time for nothing. Removed; EnhancedChatbot manages the model.

class EnhancedChatbot:
    """Owns the language model, the persisted configuration, and generation.

    A single instance is created at module level and shared by all Gradio
    callbacks; ``model_lock`` serializes access to the model across requests.
    """

    def __init__(self):
        self.model = None                 # set by load_model()
        self.config = self.load_config()  # dict of settings, see load_config
        self.model_lock = Lock()          # one generation at a time
        self.load_model()

    def load_config(self):
        """Return settings merged from CONFIG_FILE over built-in defaults.

        The original implementation returned the file contents verbatim, so a
        config file missing any key (e.g. written by an older version) caused
        a KeyError later. Defaults now always back-fill missing keys, and a
        corrupt/unreadable file falls back to defaults instead of crashing.
        """
        defaults = {
            "model_name": MODEL_NAME,
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "system_message": "You are a friendly and helpful AI assistant.",
            "gpu_layers": 0,
        }
        if os.path.exists(CONFIG_FILE):
            try:
                with open(CONFIG_FILE, 'r') as f:
                    defaults.update(json.load(f))  # file values win
            except (OSError, json.JSONDecodeError) as e:
                logging.error(f"Error reading config file, using defaults: {e}")
        return defaults

    def save_config(self):
        """Persist the current configuration to CONFIG_FILE as indented JSON."""
        with open(CONFIG_FILE, 'w') as f:
            json.dump(self.config, f, indent=2)

    def load_model(self):
        """(Re)load the model named in config; logs and re-raises on failure."""
        # NOTE(review): `model_type=` and `gpu_layers=` are keyword arguments of
        # ctransformers' AutoModelForCausalLM, not the `transformers` class
        # imported at the top of this file — and generate_response calls the
        # model as `self.model(prompt, ...)`, which is also the ctransformers
        # calling convention. Confirm the intended library; kept as-is here.
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config["model_name"],
                model_type="llama",
                gpu_layers=self.config["gpu_layers"],
                cache_dir=CACHE_DIR
            )
            logging.info(f"Model loaded successfully: {self.config['model_name']}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    def generate_response(self, message, history, system_message, max_tokens, temperature, top_p):
        """Build a Human/Assistant transcript prompt and return the stripped reply.

        Args:
            message: the new user message.
            history: list of (user_msg, assistant_msg) pairs from prior turns.
            system_message: text prepended to the prompt.
            max_tokens / temperature / top_p: sampling parameters forwarded
                to the model call.
        """
        prompt = f"{system_message}\n\n"
        for user_msg, assistant_msg in history:
            prompt += f"Human: {user_msg}\nAssistant: {assistant_msg}\n"
        prompt += f"Human: {message}\nAssistant: "

        start_time = time.time()
        with self.model_lock:  # the shared model is not safe for concurrent calls
            generated_text = self.model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
        end_time = time.time()

        response_time = end_time - start_time
        logging.info(f"Response generated in {response_time:.2f} seconds")

        return generated_text.strip()

# Shared singleton used by every Gradio callback; loads the model at import time.
chatbot = EnhancedChatbot()

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio ChatInterface callback: yield one reply, or an apology on failure.

    Errors are logged server-side; the user only sees a generic message.
    """
    try:
        yield chatbot.generate_response(
            message, history, system_message, max_tokens, temperature, top_p
        )
    except Exception as e:
        logging.error(f"Error generating response: {str(e)}")
        yield "I apologize, but I encountered an error while processing your request. Please try again."

def update_model_config(model_name, gpu_layers):
    """Switch to a new model, persisting the config only after a successful load.

    The original implementation saved the config BEFORE attempting the load, so
    a typo'd model name permanently corrupted the saved config file. Now the
    previous settings are restored and the exception propagates (Gradio shows
    it to the user) if loading fails; the config is saved only on success.
    """
    previous = (chatbot.config["model_name"], chatbot.config["gpu_layers"])
    chatbot.config["model_name"] = model_name
    chatbot.config["gpu_layers"] = gpu_layers
    try:
        chatbot.load_model()
    except Exception:
        # Roll back so the in-memory config still matches the working model.
        chatbot.config["model_name"], chatbot.config["gpu_layers"] = previous
        raise
    chatbot.save_config()
    return f"Model updated to {model_name} with {gpu_layers} GPU layers."

def update_system_message(system_message):
    """Store a new system prompt in the shared config, persist it, and confirm."""
    chatbot.config["system_message"] = system_message
    chatbot.save_config()
    return f"System message updated: {system_message}"

# Build the Gradio UI: a Chat tab backed by `respond`, and a Settings tab that
# edits the shared chatbot configuration.
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced AI Chatbot")

    with gr.Tab("Chat"):
        chatbot_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value=chatbot.config["system_message"], label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=chatbot.config["max_tokens"], step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=chatbot.config["temperature"], step=0.1, label="Temperature"),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=chatbot.config["top_p"],
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                ),
            ],
        )

    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("### Model Settings")
            model_name_input = gr.Textbox(value=chatbot.config["model_name"], label="Model name")
            gpu_layers_input = gr.Slider(minimum=0, maximum=8, value=chatbot.config["gpu_layers"], step=1, label="GPU layers")
            update_model_button = gr.Button("Update model")
            # BUG FIX: Blocks event listeners take component objects for
            # `outputs`, not the string "text" — add an explicit status box.
            model_status = gr.Textbox(label="Status", interactive=False)
            update_model_button.click(
                update_model_config,
                inputs=[model_name_input, gpu_layers_input],
                outputs=[model_status],
            )

        with gr.Group():
            gr.Markdown("### System Message Settings")
            system_message_input = gr.Textbox(value=chatbot.config["system_message"], label="System message")
            update_system_message_button = gr.Button("Update system message")
            system_status = gr.Textbox(label="Status", interactive=False)
            update_system_message_button.click(
                update_system_message,
                inputs=[system_message_input],
                outputs=[system_status],
            )
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()