File size: 3,877 Bytes
52e06b8
12259bb
d10e7c0
4b1a870
 
 
 
 
 
52e06b8
12259bb
8cc1975
12259bb
 
52e06b8
aedc2eb
32c1879
12259bb
 
 
8cc1975
e1e0964
12259bb
8cc1975
52e06b8
12d9ade
 
 
 
 
 
 
 
 
 
 
 
 
52e06b8
 
e1e0964
52e06b8
 
 
 
b7ac79a
52e06b8
 
 
12259bb
 
 
8cc1975
4b1a870
 
644a4cc
12d9ade
4b1a870
 
e1e0964
4b1a870
12d9ade
8cc1975
12d9ade
12259bb
12d9ade
 
12259bb
 
12d9ade
 
 
 
 
 
 
 
 
 
 
 
e1e0964
12d9ade
 
4b1a870
 
12d9ade
4b1a870
 
e1e0964
12d9ade
 
 
 
694bc80
12d9ade
8cc1975
12d9ade
 
 
6da03bd
a535c5a
 
fab8e58
 
6da03bd
 
12d9ade
6da03bd
 
 
 
 
 
 
 
 
 
 
12d9ade
6da03bd
 
 
 
 
 
 
 
 
 
 
12d9ade
 
 
 
 
 
 
 
 
 
 
6da03bd
fab8e58
 
12d9ade
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import spaces
import os
import json
import subprocess
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import gradio as gr
from huggingface_hub import hf_hub_download

# Lazily-populated model cache used by respond(): the loaded Llama instance
# and the filename of the model it was loaded from.
llm = None
llm_model = None

# Optional access token read from the environment; None when the variable
# is not set.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# Fetch the GGUF weights into ./models at import time so the file is on
# disk before the first chat request tries to load it.
hf_hub_download(
    token=huggingface_token,
    local_dir="./models",
    repo_id="SakuraLLM/Sakura-14B-Qwen2beta-v0.9.2-GGUF",
    filename="sakura-14b-qwen2beta-v0.9.2-q6k.gguf",
)

@spaces.GPU(duration=120)
def respond(
    message,
    history: list[tuple[str, str]],
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
):
    """Stream a chat completion for ``message`` from the selected GGUF model.

    The model is loaded from ``models/{model}`` on first use and cached in
    the module-level ``llm`` / ``llm_model`` globals; it is reloaded only
    when a different ``model`` filename is requested.

    Args:
        message: The new user message to answer.
        history: Previous turns as ``(user_text, assistant_text)`` pairs.
        model: Filename of the GGUF model inside the ``models`` directory.
        system_message: System prompt passed to the agent.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repeat_penalty: Repetition penalty.

    Yields:
        The response text accumulated so far, once per streamed chunk.
    """
    # The selectable model is Qwen-based (Sakura-14B-Qwen2beta), which is
    # prompted with ChatML turn markers rather than the Gemma 2 format.
    chat_template = MessagesFormatterType.CHATML

    global llm
    global llm_model

    # (Re)load the model only when nothing is cached or the selection changed.
    if llm is None or llm_model != model:
        llm = Llama(
            model_path=f"models/{model}",
            flash_attn=True,
            n_gpu_layers=81,
            n_batch=1024,
            n_ctx=8192,
        )
        llm_model = model

    provider = LlamaCppPythonProvider(llm)

    agent = LlamaCppAgent(
        provider,
        system_prompt=f"{system_message}",
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    # Honour the caller-supplied sampling parameters. Previously temperature,
    # top_p and repeat_penalty were hard-coded, so the matching UI controls
    # had no effect.
    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True

    # Replay the previous turns into the agent's chat history format.
    messages = BasicChatHistory()
    for user_text, assistant_text in history:
        messages.add_message({
            'role': Roles.user,
            'content': user_text
        })
        messages.add_message({
            'role': Roles.assistant,
            'content': assistant_text
        })

    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
        chat_history=messages,
        returns_streaming_generator=True,
        print_output=False
    )

    # Yield the cumulative text after every chunk so the caller always
    # receives the full response generated so far.
    outputs = ""
    for output in stream:
        outputs += output
        yield outputs

# HTML blurb shown under the title of the chat interface.
description = (
    '<p align="center">Defaults to Sakura-14B-Qwen2beta '
    '(you can switch from additional inputs)</p>'
)

# Extra controls, declared in the same order as respond()'s parameters after
# (message, history): model, system_message, max_tokens, temperature, top_p,
# top_k, repeat_penalty.
model_dropdown = gr.Dropdown(
    ['sakura-14b-qwen2beta-v0.9.2-q6k.gguf'],
    value="sakura-14b-qwen2beta-v0.9.2-q6k.gguf",
    label="Model",
)
system_message_box = gr.Textbox(
    value="你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。",
    label="System message",
)
max_tokens_slider = gr.Slider(
    minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"
)
temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"
)
top_p_slider = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.3, step=0.05, label="Top-p"
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=40, step=1, label="Top-k"
)
repeat_penalty_slider = gr.Slider(
    minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="Repetition penalty"
)

# Chat display widget: copy button enabled, like/dislike buttons disabled.
chat_window = gr.Chatbot(
    scale=1,
    likeable=False,
    show_copy_button=True,
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        model_dropdown,
        system_message_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        top_k_slider,
        repeat_penalty_slider,
    ],
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    title="Chat with Sakura-14B-Qwen2beta using llama.cpp",
    description=description,
    chatbot=chat_window,
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()