sadzxctv committed
Commit e1e0964
Parent: d9919e4

Update app.py

Files changed (1)
  1. app.py +64 -60
app.py CHANGED
@@ -1,86 +1,78 @@
import os
- import spaces
import json
import subprocess
+ import gradio as gr
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
- import gradio as gr
from huggingface_hub import hf_hub_download
- import logging
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)

- # Parameters
- REPO_ID = os.getenv("REPO_ID", "SakuraLLM/Sakura-14B-Qwen2beta-v0.9.2-GGUF")
- FILENAME = os.getenv("FILENAME", "sakura-14b-qwen2beta-v0.9.2-q4km.gguf")
- SYSTEM_MESSAGE = os.getenv("SYSTEM_MESSAGE", "你是一个轻小说翻译模型,可以流畅通顺地使用给定的术语表以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要混淆使役态和被动态的主语和宾语,不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。")
+ # Configuration (environment variables or a config file)
+ REPO_ID = "SakuraLLM/Sakura-14B-Qwen2beta-v0.9.2-GGUF"
+ FILENAME = "sakura-14b-qwen2beta-v0.9.2-q4km.gguf"
MODEL_DIR = "./models"
+ DEFAULT_SYSTEM_MESSAGE = ("你是一个轻小说翻译模型,可以流畅通顺地使用给定的术语表以日本轻小说的风格将日文翻译成简体中文,"
+                           "并联系上下文正确使用人称代词,注意不要混淆使役态和被动态的主语和宾语,不要擅自添加原文中没有的代词,"
+                           "也不要擅自增加或减少换行。")

- # Download the model
- def download_model(repo_id, filename, local_dir):
-     logger.info(f"Downloading model {filename} from {repo_id} to {local_dir}")
-     hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
+ llm = None
+ llm_model = None
+
+ def download_model(repo_id, filename, local_dir):
+     """Download the model."""
+     try:
+         hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             local_dir=local_dir
+         )
+     except Exception as e:
+         print(f"Failed to download the model: {e}")

- # Initialize the Llama model
- def initialize_llama(model_path):
-     logger.info(f"Initializing Llama model from {model_path}")
-     return Llama(
-         model_path=model_path,
-         flash_attn=True,
-         n_gpu_layers=81,
-         n_batch=1024,
-         n_ctx=8192,
-     )
+ def load_model(model_path, model):
+     """Load the model."""
+     global llm
+     global llm_model
+
+     if llm is None or llm_model != model:
+         try:
+             llm = Llama(
+                 model_path=model_path,
+                 flash_attn=True,
+                 n_gpu_layers=81,
+                 n_batch=1024,
+                 n_ctx=8192,
+             )
+             llm_model = model
+         except Exception as e:
+             print(f"Failed to load the model: {e}")

- # Process the chat history
- def process_history(history):
-     messages = BasicChatHistory()
-     for msn in history:
-         user = {
-             'role': Roles.user,
-             'content': "根据以下术语表(可以为空):\n" + "将下面的日文文本根据上述术语表的对应关系和备注翻译成中文,并且列印出使用哪些术语表:" + msn[0]
-         }
-         assistant = {
-             'role': Roles.assistant,
-             'content': msn[1]
-         }
-         messages.add_message(user)
-         messages.add_message(assistant)
-     return messages
-
- # Main function
- @spaces.GPU(duration=120)
def respond(
    message,
-     history: list[tuple[str, str]],
+     history,
    model=FILENAME,
-     system_message=SYSTEM_MESSAGE,
+     system_message=DEFAULT_SYSTEM_MESSAGE,
    max_tokens=4096,
    temperature=0.1,
    top_p=0.3,
    top_k=40,
    repeat_penalty=1.1,
):
-     global llm
-     global llm_model
-
-     if llm is None or llm_model != model:
-         llm = initialize_llama(f"{MODEL_DIR}/{model}")
-         llm_model = model
-
+     """Generate a streaming chat response."""
+     chat_template = MessagesFormatterType.GEMMA_2
+
+     load_model(f"{MODEL_DIR}/{model}", model)
+
    provider = LlamaCppPythonProvider(llm)
    agent = LlamaCppAgent(
        provider,
-         system_prompt=system_message,
-         predefined_messages_formatter_type=MessagesFormatterType.GEMMA_2,
+         system_prompt=f"{system_message}",
+         predefined_messages_formatter_type=chat_template,
        debug_output=True
    )
-
+
    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
@@ -88,9 +80,21 @@ def respond(
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True
-
-     messages = process_history(history)
-
+
+     messages = BasicChatHistory()
+
+     for msn in history:
+         user = {
+             'role': Roles.user,
+             'content': "根据以下术语表(可以为空):\n" + "将下面的日文文本根据上述术语表的对应关系和备注翻译成中文,并且列印出使用哪些术语表:" + msn[0]
+         }
+         assistant = {
+             'role': Roles.assistant,
+             'content': msn[1]
+         }
+         messages.add_message(user)
+         messages.add_message(assistant)
+
    stream = agent.get_chat_response(
        message,
        llm_sampling_settings=settings,
@@ -98,14 +102,14 @@ def respond(
        returns_streaming_generator=True,
        print_output=False
    )
-
+
    outputs = ""
    for output in stream:
        outputs += output
        outputs = outputs.replace(system_message, '')
+
        yield outputs

- # Gradio interface
description = """<p align="center">Defaults to Sakura-14B-Qwen2beta</p>
<p><center>
<a href="https://huggingface.co/SakuraLLM/Sakura-14B-Qwen2beta-v0.9.2-GGUF" target="_blank">[Sakura-14B-Qwen2beta Model]</a>
@@ -129,4 +133,4 @@ demo = gr.ChatInterface(

if __name__ == "__main__":
    download_model(REPO_ID, FILENAME, MODEL_DIR)
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+     demo.launch()
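
The diff elides the body of demo = gr.ChatInterface(. For orientation, here is a minimal sketch of how a streaming generator like respond() is typically wired into gr.ChatInterface; the slider ranges, labels, and widget choices below are assumptions for illustration, not the Space's actual values:

# Minimal sketch, within app.py after respond() and description are defined.
# NOTE: this block is hypothetical — the real ChatInterface body is elided by the diff.
# additional_inputs are passed to respond() positionally after (message, history),
# so their order must match respond()'s remaining parameters.
demo = gr.ChatInterface(
    respond,                      # generator: yields progressively longer outputs
    description=description,
    additional_inputs=[
        gr.Textbox(value=FILENAME, label="Model"),
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max tokens"),
        gr.Slider(minimum=0.0, maximum=2.0, value=0.1, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, label="Top-p"),
        gr.Slider(minimum=0, maximum=100, value=40, step=1, label="Top-k"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.05, label="Repeat penalty"),
    ],
)

Because respond() yields the accumulated string on each step, ChatInterface renders the reply as a live stream rather than waiting for the full completion.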