Spaces:
Runtime error
Runtime error
JohnSmith9982
commited on
Commit
•
7b5a1c0
1
Parent(s):
cecb277
Upload 38 files
Browse files- CITATION.cff +20 -0
- ChuanhuChatbot.py +159 -128
- Dockerfile +2 -1
- README.md +105 -13
- assets/custom.css +125 -31
- assets/custom.js +208 -54
- config_example.json +31 -0
- configs/ds_config_chatbot.json +17 -0
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-311.pyc +0 -0
- modules/__pycache__/__init__.cpython-39.pyc +0 -0
- modules/__pycache__/base_model.cpython-311.pyc +0 -0
- modules/__pycache__/base_model.cpython-39.pyc +0 -0
- modules/__pycache__/config.cpython-311.pyc +0 -0
- modules/__pycache__/config.cpython-39.pyc +0 -0
- modules/__pycache__/llama_func.cpython-311.pyc +0 -0
- modules/__pycache__/models.cpython-311.pyc +0 -0
- modules/base_model.py +547 -0
- modules/config.py +55 -34
- modules/llama_func.py +75 -46
- modules/models.py +586 -0
- modules/overwrites.py +55 -17
- modules/presets.py +82 -49
- modules/shared.py +3 -3
- modules/utils.py +115 -118
- requirements.txt +1 -0
- requirements_advanced.txt +7 -0
- run_Linux.sh +8 -2
- run_macOS.command +8 -2
CITATION.cff
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
title: ChuanhuChatGPT
|
3 |
+
message: >-
|
4 |
+
If you use this software, please cite it using these
|
5 |
+
metadata.
|
6 |
+
type: software
|
7 |
+
authors:
|
8 |
+
- given-names: Chuanhu
|
9 |
+
orcid: https://orcid.org/0000-0001-8954-8598
|
10 |
+
- given-names: MZhao
|
11 |
+
orcid: https://orcid.org/0000-0003-2298-6213
|
12 |
+
- given-names: Keldos
|
13 |
+
orcid: https://orcid.org/0009-0005-0357-272X
|
14 |
+
repository-code: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
|
15 |
+
url: 'https://github.com/GaiZhenbiao/ChuanhuChatGPT'
|
16 |
+
abstract: Provided a light and easy to use interface for ChatGPT API
|
17 |
+
license: GPL-3.0
|
18 |
+
commit: bd0034c37e5af6a90bd9c2f7dd073f6cd27c61af
|
19 |
+
version: '20230405'
|
20 |
+
date-released: '2023-04-05'
|
ChuanhuChatbot.py
CHANGED
@@ -10,31 +10,32 @@ from modules.config import *
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
-
from modules.
|
14 |
-
from modules.openai_func import get_usage
|
15 |
|
|
|
16 |
gr.Chatbot.postprocess = postprocess
|
17 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
18 |
|
19 |
with open("assets/custom.css", "r", encoding="utf-8") as f:
|
20 |
customCSS = f.read()
|
21 |
|
|
|
|
|
|
|
22 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
23 |
user_name = gr.State("")
|
24 |
-
history = gr.State([])
|
25 |
-
token_count = gr.State([])
|
26 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
27 |
-
user_api_key = gr.State(my_api_key)
|
28 |
user_question = gr.State("")
|
29 |
-
|
|
|
|
|
30 |
topic = gr.State("未命名对话历史记录")
|
31 |
|
32 |
with gr.Row():
|
33 |
-
|
34 |
-
gr.HTML(title)
|
35 |
-
user_info = gr.Markdown(value="", elem_id="user_info")
|
36 |
-
gr.HTML('<center><a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a></center>')
|
37 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
|
|
|
|
38 |
|
39 |
# https://github.com/gradio-app/gradio/pull/3296
|
40 |
def create_greeting(request: gr.Request):
|
@@ -50,14 +51,14 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
50 |
with gr.Row():
|
51 |
chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
|
52 |
with gr.Row():
|
53 |
-
with gr.Column(scale=12):
|
54 |
user_input = gr.Textbox(
|
55 |
elem_id="user_input_tb",
|
56 |
show_label=False, placeholder="在这里输入"
|
57 |
).style(container=False)
|
58 |
-
with gr.Column(min_width=
|
59 |
-
submitBtn = gr.Button("
|
60 |
-
cancelBtn = gr.Button("
|
61 |
with gr.Row():
|
62 |
emptyBtn = gr.Button(
|
63 |
"🧹 新的对话",
|
@@ -65,37 +66,41 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
65 |
retryBtn = gr.Button("🔄 重新生成")
|
66 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
67 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
68 |
-
reduceTokenBtn = gr.Button("♻️ 总结对话")
|
69 |
|
70 |
with gr.Column():
|
71 |
with gr.Column(min_width=50, scale=1):
|
72 |
-
with gr.Tab(label="
|
73 |
keyTxt = gr.Textbox(
|
74 |
show_label=True,
|
75 |
placeholder=f"OpenAI API-key...",
|
76 |
-
value=hide_middle_chars(
|
77 |
type="password",
|
78 |
visible=not HIDE_MY_KEY,
|
79 |
label="API-Key",
|
80 |
)
|
81 |
if multi_api_key:
|
82 |
-
usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display")
|
83 |
else:
|
84 |
-
usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display")
|
85 |
model_select_dropdown = gr.Dropdown(
|
86 |
-
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[
|
87 |
)
|
88 |
-
|
89 |
-
label="
|
90 |
)
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
language_select_dropdown = gr.Dropdown(
|
93 |
label="选择回复语言(针对搜索&索引功能)",
|
94 |
choices=REPLY_LANGUAGES,
|
95 |
multiselect=False,
|
96 |
value=REPLY_LANGUAGES[0],
|
97 |
)
|
98 |
-
index_files = gr.Files(label="上传索引文件", type="file"
|
99 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
100 |
# TODO: 公式ocr
|
101 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
@@ -105,7 +110,7 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
105 |
show_label=True,
|
106 |
placeholder=f"在这里输入System Prompt...",
|
107 |
label="System prompt",
|
108 |
-
value=
|
109 |
lines=10,
|
110 |
).style(container=False)
|
111 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
@@ -161,27 +166,87 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
161 |
|
162 |
with gr.Tab(label="高级"):
|
163 |
gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
|
164 |
-
|
165 |
-
|
166 |
with gr.Accordion("参数", open=False):
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
minimum=-0,
|
169 |
maximum=1.0,
|
170 |
value=1.0,
|
171 |
step=0.05,
|
172 |
interactive=True,
|
173 |
-
label="
|
174 |
)
|
175 |
-
|
176 |
-
minimum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
maximum=2.0,
|
178 |
-
value=
|
179 |
-
step=0.
|
180 |
interactive=True,
|
181 |
-
label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
)
|
183 |
|
184 |
-
with gr.Accordion("网络设置", open=False
|
185 |
# 优先展示自定义的api_host
|
186 |
apihostTxt = gr.Textbox(
|
187 |
show_label=True,
|
@@ -199,27 +264,22 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
199 |
lines=2,
|
200 |
)
|
201 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
|
|
202 |
|
203 |
-
gr.Markdown(
|
204 |
-
gr.HTML(
|
205 |
chatgpt_predict_args = dict(
|
206 |
fn=predict,
|
207 |
inputs=[
|
208 |
-
|
209 |
-
systemPromptTxt,
|
210 |
-
history,
|
211 |
user_question,
|
212 |
chatbot,
|
213 |
-
token_count,
|
214 |
-
top_p,
|
215 |
-
temperature,
|
216 |
use_streaming_checkbox,
|
217 |
-
model_select_dropdown,
|
218 |
use_websearch_checkbox,
|
219 |
index_files,
|
220 |
language_select_dropdown,
|
221 |
],
|
222 |
-
outputs=[chatbot,
|
223 |
show_progress=True,
|
224 |
)
|
225 |
|
@@ -243,12 +303,18 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
243 |
)
|
244 |
|
245 |
get_usage_args = dict(
|
246 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
)
|
248 |
|
249 |
|
250 |
# Chatbot
|
251 |
-
cancelBtn.click(
|
252 |
|
253 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
254 |
user_input.submit(**get_usage_args)
|
@@ -256,9 +322,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
256 |
submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
257 |
submitBtn.click(**get_usage_args)
|
258 |
|
|
|
|
|
259 |
emptyBtn.click(
|
260 |
-
|
261 |
-
|
|
|
262 |
show_progress=True,
|
263 |
)
|
264 |
emptyBtn.click(**reset_textbox_args)
|
@@ -266,61 +335,42 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
266 |
retryBtn.click(**start_outputing_args).then(
|
267 |
retry,
|
268 |
[
|
269 |
-
|
270 |
-
systemPromptTxt,
|
271 |
-
history,
|
272 |
chatbot,
|
273 |
-
token_count,
|
274 |
-
top_p,
|
275 |
-
temperature,
|
276 |
use_streaming_checkbox,
|
277 |
-
|
|
|
278 |
language_select_dropdown,
|
279 |
],
|
280 |
-
[chatbot,
|
281 |
show_progress=True,
|
282 |
).then(**end_outputing_args)
|
283 |
retryBtn.click(**get_usage_args)
|
284 |
|
285 |
delFirstBtn.click(
|
286 |
delete_first_conversation,
|
287 |
-
[
|
288 |
-
[
|
289 |
)
|
290 |
|
291 |
delLastBtn.click(
|
292 |
delete_last_conversation,
|
293 |
-
[
|
294 |
-
[chatbot,
|
295 |
-
show_progress=
|
296 |
)
|
297 |
|
298 |
-
reduceTokenBtn.click(
|
299 |
-
reduce_token_size,
|
300 |
-
[
|
301 |
-
user_api_key,
|
302 |
-
systemPromptTxt,
|
303 |
-
history,
|
304 |
-
chatbot,
|
305 |
-
token_count,
|
306 |
-
top_p,
|
307 |
-
temperature,
|
308 |
-
gr.State(sum(token_count.value[-4:])),
|
309 |
-
model_select_dropdown,
|
310 |
-
language_select_dropdown,
|
311 |
-
],
|
312 |
-
[chatbot, history, status_display, token_count],
|
313 |
-
show_progress=True,
|
314 |
-
)
|
315 |
-
reduceTokenBtn.click(**get_usage_args)
|
316 |
-
|
317 |
two_column.change(update_doc_config, [two_column], None)
|
318 |
|
319 |
-
#
|
320 |
-
keyTxt.change(
|
321 |
keyTxt.submit(**get_usage_args)
|
|
|
|
|
|
|
322 |
|
323 |
# Template
|
|
|
324 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
325 |
templateFileSelectDropdown.change(
|
326 |
load_template,
|
@@ -338,31 +388,33 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
338 |
# S&L
|
339 |
saveHistoryBtn.click(
|
340 |
save_chat_history,
|
341 |
-
[
|
342 |
downloadFile,
|
343 |
show_progress=True,
|
344 |
)
|
345 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
346 |
exportMarkdownBtn.click(
|
347 |
export_markdown,
|
348 |
-
[
|
349 |
downloadFile,
|
350 |
show_progress=True,
|
351 |
)
|
352 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
353 |
-
historyFileSelectDropdown.change(
|
354 |
-
|
355 |
-
[historyFileSelectDropdown, systemPromptTxt, history, chatbot, user_name],
|
356 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
357 |
-
show_progress=True,
|
358 |
-
)
|
359 |
-
downloadFile.change(
|
360 |
-
load_chat_history,
|
361 |
-
[downloadFile, systemPromptTxt, history, chatbot, user_name],
|
362 |
-
[saveFileName, systemPromptTxt, history, chatbot],
|
363 |
-
)
|
364 |
|
365 |
# Advanced
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
default_btn.click(
|
367 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
368 |
)
|
@@ -389,35 +441,14 @@ demo.title = "川虎ChatGPT 🚀"
|
|
389 |
|
390 |
if __name__ == "__main__":
|
391 |
reload_javascript()
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
server_name="0.0.0.0",
|
404 |
-
server_port=7860,
|
405 |
-
share=False,
|
406 |
-
favicon_path="./assets/favicon.ico",
|
407 |
-
)
|
408 |
-
# if not running in Docker
|
409 |
-
else:
|
410 |
-
if authflag:
|
411 |
-
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
412 |
-
share=False,
|
413 |
-
auth=auth_list,
|
414 |
-
favicon_path="./assets/favicon.ico",
|
415 |
-
inbrowser=True,
|
416 |
-
)
|
417 |
-
else:
|
418 |
-
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
419 |
-
share=False, favicon_path="./assets/favicon.ico", inbrowser=True
|
420 |
-
) # 改为 share=True 可以创建公开分享链接
|
421 |
-
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
|
422 |
-
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
|
423 |
-
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
|
|
|
10 |
from modules.utils import *
|
11 |
from modules.presets import *
|
12 |
from modules.overwrites import *
|
13 |
+
from modules.models import get_model
|
|
|
14 |
|
15 |
+
gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
|
16 |
gr.Chatbot.postprocess = postprocess
|
17 |
PromptHelper.compact_text_chunks = compact_text_chunks
|
18 |
|
19 |
with open("assets/custom.css", "r", encoding="utf-8") as f:
|
20 |
customCSS = f.read()
|
21 |
|
22 |
+
def create_new_model():
|
23 |
+
return get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
|
24 |
+
|
25 |
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
26 |
user_name = gr.State("")
|
|
|
|
|
27 |
promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
|
|
|
28 |
user_question = gr.State("")
|
29 |
+
user_api_key = gr.State(my_api_key)
|
30 |
+
current_model = gr.State(create_new_model)
|
31 |
+
|
32 |
topic = gr.State("未命名对话历史记录")
|
33 |
|
34 |
with gr.Row():
|
35 |
+
gr.HTML(CHUANHU_TITLE, elem_id="app_title")
|
|
|
|
|
|
|
36 |
status_display = gr.Markdown(get_geoip(), elem_id="status_display")
|
37 |
+
with gr.Row(elem_id="float_display"):
|
38 |
+
user_info = gr.Markdown(value="getting user info...", elem_id="user_info")
|
39 |
|
40 |
# https://github.com/gradio-app/gradio/pull/3296
|
41 |
def create_greeting(request: gr.Request):
|
|
|
51 |
with gr.Row():
|
52 |
chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
|
53 |
with gr.Row():
|
54 |
+
with gr.Column(min_width=225, scale=12):
|
55 |
user_input = gr.Textbox(
|
56 |
elem_id="user_input_tb",
|
57 |
show_label=False, placeholder="在这里输入"
|
58 |
).style(container=False)
|
59 |
+
with gr.Column(min_width=42, scale=1):
|
60 |
+
submitBtn = gr.Button(value="", variant="primary", elem_id="submit_btn")
|
61 |
+
cancelBtn = gr.Button(value="", variant="secondary", visible=False, elem_id="cancel_btn")
|
62 |
with gr.Row():
|
63 |
emptyBtn = gr.Button(
|
64 |
"🧹 新的对话",
|
|
|
66 |
retryBtn = gr.Button("🔄 重新生成")
|
67 |
delFirstBtn = gr.Button("🗑️ 删除最旧对话")
|
68 |
delLastBtn = gr.Button("🗑️ 删除最新对话")
|
|
|
69 |
|
70 |
with gr.Column():
|
71 |
with gr.Column(min_width=50, scale=1):
|
72 |
+
with gr.Tab(label="模型"):
|
73 |
keyTxt = gr.Textbox(
|
74 |
show_label=True,
|
75 |
placeholder=f"OpenAI API-key...",
|
76 |
+
value=hide_middle_chars(user_api_key.value),
|
77 |
type="password",
|
78 |
visible=not HIDE_MY_KEY,
|
79 |
label="API-Key",
|
80 |
)
|
81 |
if multi_api_key:
|
82 |
+
usageTxt = gr.Markdown("多账号模式已开启,无需输入key,可直接开始对话", elem_id="usage_display", elem_classes="insert_block")
|
83 |
else:
|
84 |
+
usageTxt = gr.Markdown("**发送消息** 或 **提交key** 以显示额度", elem_id="usage_display", elem_classes="insert_block")
|
85 |
model_select_dropdown = gr.Dropdown(
|
86 |
+
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
|
87 |
)
|
88 |
+
lora_select_dropdown = gr.Dropdown(
|
89 |
+
label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
|
90 |
)
|
91 |
+
with gr.Row():
|
92 |
+
use_streaming_checkbox = gr.Checkbox(
|
93 |
+
label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
|
94 |
+
)
|
95 |
+
single_turn_checkbox = gr.Checkbox(label="单轮对话", value=False)
|
96 |
+
use_websearch_checkbox = gr.Checkbox(label="使用在线搜索", value=False)
|
97 |
language_select_dropdown = gr.Dropdown(
|
98 |
label="选择回复语言(针对搜索&索引功能)",
|
99 |
choices=REPLY_LANGUAGES,
|
100 |
multiselect=False,
|
101 |
value=REPLY_LANGUAGES[0],
|
102 |
)
|
103 |
+
index_files = gr.Files(label="上传索引文件", type="file")
|
104 |
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
105 |
# TODO: 公式ocr
|
106 |
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
|
|
110 |
show_label=True,
|
111 |
placeholder=f"在这里输入System Prompt...",
|
112 |
label="System prompt",
|
113 |
+
value=INITIAL_SYSTEM_PROMPT,
|
114 |
lines=10,
|
115 |
).style(container=False)
|
116 |
with gr.Accordion(label="加载Prompt模板", open=True):
|
|
|
166 |
|
167 |
with gr.Tab(label="高级"):
|
168 |
gr.Markdown("# ⚠️ 务必谨慎更改 ⚠️\n\n如果无法使用请恢复默认设置")
|
169 |
+
gr.HTML(APPEARANCE_SWITCHER, elem_classes="insert_block")
|
|
|
170 |
with gr.Accordion("参数", open=False):
|
171 |
+
temperature_slider = gr.Slider(
|
172 |
+
minimum=-0,
|
173 |
+
maximum=2.0,
|
174 |
+
value=1.0,
|
175 |
+
step=0.1,
|
176 |
+
interactive=True,
|
177 |
+
label="temperature",
|
178 |
+
)
|
179 |
+
top_p_slider = gr.Slider(
|
180 |
minimum=-0,
|
181 |
maximum=1.0,
|
182 |
value=1.0,
|
183 |
step=0.05,
|
184 |
interactive=True,
|
185 |
+
label="top-p",
|
186 |
)
|
187 |
+
n_choices_slider = gr.Slider(
|
188 |
+
minimum=1,
|
189 |
+
maximum=10,
|
190 |
+
value=1,
|
191 |
+
step=1,
|
192 |
+
interactive=True,
|
193 |
+
label="n choices",
|
194 |
+
)
|
195 |
+
stop_sequence_txt = gr.Textbox(
|
196 |
+
show_label=True,
|
197 |
+
placeholder=f"在这里输入停止符,用英文逗号隔开...",
|
198 |
+
label="stop",
|
199 |
+
value="",
|
200 |
+
lines=1,
|
201 |
+
)
|
202 |
+
max_context_length_slider = gr.Slider(
|
203 |
+
minimum=1,
|
204 |
+
maximum=32768,
|
205 |
+
value=2000,
|
206 |
+
step=1,
|
207 |
+
interactive=True,
|
208 |
+
label="max context",
|
209 |
+
)
|
210 |
+
max_generation_slider = gr.Slider(
|
211 |
+
minimum=1,
|
212 |
+
maximum=32768,
|
213 |
+
value=1000,
|
214 |
+
step=1,
|
215 |
+
interactive=True,
|
216 |
+
label="max generations",
|
217 |
+
)
|
218 |
+
presence_penalty_slider = gr.Slider(
|
219 |
+
minimum=-2.0,
|
220 |
maximum=2.0,
|
221 |
+
value=0.0,
|
222 |
+
step=0.01,
|
223 |
interactive=True,
|
224 |
+
label="presence penalty",
|
225 |
+
)
|
226 |
+
frequency_penalty_slider = gr.Slider(
|
227 |
+
minimum=-2.0,
|
228 |
+
maximum=2.0,
|
229 |
+
value=0.0,
|
230 |
+
step=0.01,
|
231 |
+
interactive=True,
|
232 |
+
label="frequency penalty",
|
233 |
+
)
|
234 |
+
logit_bias_txt = gr.Textbox(
|
235 |
+
show_label=True,
|
236 |
+
placeholder=f"word:likelihood",
|
237 |
+
label="logit bias",
|
238 |
+
value="",
|
239 |
+
lines=1,
|
240 |
+
)
|
241 |
+
user_identifier_txt = gr.Textbox(
|
242 |
+
show_label=True,
|
243 |
+
placeholder=f"用于定位滥用行为",
|
244 |
+
label="用户名",
|
245 |
+
value=user_name.value,
|
246 |
+
lines=1,
|
247 |
)
|
248 |
|
249 |
+
with gr.Accordion("网络设置", open=False):
|
250 |
# 优先展示自定义的api_host
|
251 |
apihostTxt = gr.Textbox(
|
252 |
show_label=True,
|
|
|
264 |
lines=2,
|
265 |
)
|
266 |
changeProxyBtn = gr.Button("🔄 设置代理地址")
|
267 |
+
default_btn = gr.Button("🔙 恢复默认设置")
|
268 |
|
269 |
+
gr.Markdown(CHUANHU_DESCRIPTION)
|
270 |
+
gr.HTML(FOOTER.format(versions=versions_html()), elem_id="footer")
|
271 |
chatgpt_predict_args = dict(
|
272 |
fn=predict,
|
273 |
inputs=[
|
274 |
+
current_model,
|
|
|
|
|
275 |
user_question,
|
276 |
chatbot,
|
|
|
|
|
|
|
277 |
use_streaming_checkbox,
|
|
|
278 |
use_websearch_checkbox,
|
279 |
index_files,
|
280 |
language_select_dropdown,
|
281 |
],
|
282 |
+
outputs=[chatbot, status_display],
|
283 |
show_progress=True,
|
284 |
)
|
285 |
|
|
|
303 |
)
|
304 |
|
305 |
get_usage_args = dict(
|
306 |
+
fn=billing_info, inputs=[current_model], outputs=[usageTxt], show_progress=False
|
307 |
+
)
|
308 |
+
|
309 |
+
load_history_from_file_args = dict(
|
310 |
+
fn=load_chat_history,
|
311 |
+
inputs=[current_model, historyFileSelectDropdown, chatbot, user_name],
|
312 |
+
outputs=[saveFileName, systemPromptTxt, chatbot]
|
313 |
)
|
314 |
|
315 |
|
316 |
# Chatbot
|
317 |
+
cancelBtn.click(interrupt, [current_model], [])
|
318 |
|
319 |
user_input.submit(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
320 |
user_input.submit(**get_usage_args)
|
|
|
322 |
submitBtn.click(**transfer_input_args).then(**chatgpt_predict_args).then(**end_outputing_args)
|
323 |
submitBtn.click(**get_usage_args)
|
324 |
|
325 |
+
index_files.change(handle_file_upload, [current_model, index_files, chatbot], [index_files, chatbot, status_display])
|
326 |
+
|
327 |
emptyBtn.click(
|
328 |
+
reset,
|
329 |
+
inputs=[current_model],
|
330 |
+
outputs=[chatbot, status_display],
|
331 |
show_progress=True,
|
332 |
)
|
333 |
emptyBtn.click(**reset_textbox_args)
|
|
|
335 |
retryBtn.click(**start_outputing_args).then(
|
336 |
retry,
|
337 |
[
|
338 |
+
current_model,
|
|
|
|
|
339 |
chatbot,
|
|
|
|
|
|
|
340 |
use_streaming_checkbox,
|
341 |
+
use_websearch_checkbox,
|
342 |
+
index_files,
|
343 |
language_select_dropdown,
|
344 |
],
|
345 |
+
[chatbot, status_display],
|
346 |
show_progress=True,
|
347 |
).then(**end_outputing_args)
|
348 |
retryBtn.click(**get_usage_args)
|
349 |
|
350 |
delFirstBtn.click(
|
351 |
delete_first_conversation,
|
352 |
+
[current_model],
|
353 |
+
[status_display],
|
354 |
)
|
355 |
|
356 |
delLastBtn.click(
|
357 |
delete_last_conversation,
|
358 |
+
[current_model, chatbot],
|
359 |
+
[chatbot, status_display],
|
360 |
+
show_progress=False
|
361 |
)
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
two_column.change(update_doc_config, [two_column], None)
|
364 |
|
365 |
+
# LLM Models
|
366 |
+
keyTxt.change(set_key, [current_model, keyTxt], [user_api_key, status_display]).then(**get_usage_args)
|
367 |
keyTxt.submit(**get_usage_args)
|
368 |
+
single_turn_checkbox.change(set_single_turn, [current_model, single_turn_checkbox], None)
|
369 |
+
model_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display, lora_select_dropdown], show_progress=True)
|
370 |
+
lora_select_dropdown.change(get_model, [model_select_dropdown, lora_select_dropdown, user_api_key, temperature_slider, top_p_slider, systemPromptTxt], [current_model, status_display], show_progress=True)
|
371 |
|
372 |
# Template
|
373 |
+
systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
|
374 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
375 |
templateFileSelectDropdown.change(
|
376 |
load_template,
|
|
|
388 |
# S&L
|
389 |
saveHistoryBtn.click(
|
390 |
save_chat_history,
|
391 |
+
[current_model, saveFileName, chatbot, user_name],
|
392 |
downloadFile,
|
393 |
show_progress=True,
|
394 |
)
|
395 |
saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
396 |
exportMarkdownBtn.click(
|
397 |
export_markdown,
|
398 |
+
[current_model, saveFileName, chatbot, user_name],
|
399 |
downloadFile,
|
400 |
show_progress=True,
|
401 |
)
|
402 |
historyRefreshBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
|
403 |
+
historyFileSelectDropdown.change(**load_history_from_file_args)
|
404 |
+
downloadFile.change(**load_history_from_file_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
|
406 |
# Advanced
|
407 |
+
max_context_length_slider.change(set_token_upper_limit, [current_model, max_context_length_slider], None)
|
408 |
+
temperature_slider.change(set_temperature, [current_model, temperature_slider], None)
|
409 |
+
top_p_slider.change(set_top_p, [current_model, top_p_slider], None)
|
410 |
+
n_choices_slider.change(set_n_choices, [current_model, n_choices_slider], None)
|
411 |
+
stop_sequence_txt.change(set_stop_sequence, [current_model, stop_sequence_txt], None)
|
412 |
+
max_generation_slider.change(set_max_tokens, [current_model, max_generation_slider], None)
|
413 |
+
presence_penalty_slider.change(set_presence_penalty, [current_model, presence_penalty_slider], None)
|
414 |
+
frequency_penalty_slider.change(set_frequency_penalty, [current_model, frequency_penalty_slider], None)
|
415 |
+
logit_bias_txt.change(set_logit_bias, [current_model, logit_bias_txt], None)
|
416 |
+
user_identifier_txt.change(set_user_identifier, [current_model, user_identifier_txt], None)
|
417 |
+
|
418 |
default_btn.click(
|
419 |
reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
|
420 |
)
|
|
|
441 |
|
442 |
if __name__ == "__main__":
|
443 |
reload_javascript()
|
444 |
+
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(
|
445 |
+
server_name=server_name,
|
446 |
+
server_port=server_port,
|
447 |
+
share=share,
|
448 |
+
auth=auth_list if authflag else None,
|
449 |
+
favicon_path="./assets/favicon.ico",
|
450 |
+
inbrowser=not dockerflag, # 禁止在docker下开启inbrowser
|
451 |
+
)
|
452 |
+
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
|
453 |
+
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", server_port=7860,auth=("在这里填写用户名", "在这里填写密码")) # 可设置用户名与密码
|
454 |
+
# demo.queue(concurrency_count=CONCURRENT_COUNT).launch(auth=("在这里填写用户名", "在这里填写密码")) # 适合Nginx反向代理
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
FROM python:3.9 as builder
|
2 |
RUN apt-get update && apt-get install -y build-essential
|
3 |
COPY requirements.txt .
|
|
|
4 |
RUN pip install --user -r requirements.txt
|
|
|
5 |
|
6 |
FROM python:3.9
|
7 |
MAINTAINER iskoldt
|
@@ -9,6 +11,5 @@ COPY --from=builder /root/.local /root/.local
|
|
9 |
ENV PATH=/root/.local/bin:$PATH
|
10 |
COPY . /app
|
11 |
WORKDIR /app
|
12 |
-
ENV my_api_key empty
|
13 |
ENV dockerrun yes
|
14 |
CMD ["python3", "-u", "ChuanhuChatbot.py", "2>&1", "|", "tee", "/var/log/application.log"]
|
|
|
1 |
FROM python:3.9 as builder
|
2 |
RUN apt-get update && apt-get install -y build-essential
|
3 |
COPY requirements.txt .
|
4 |
+
COPY requirements_advanced.txt .
|
5 |
RUN pip install --user -r requirements.txt
|
6 |
+
# RUN pip install --user -r requirements_advanced.txt
|
7 |
|
8 |
FROM python:3.9
|
9 |
MAINTAINER iskoldt
|
|
|
11 |
ENV PATH=/root/.local/bin:$PATH
|
12 |
COPY . /app
|
13 |
WORKDIR /app
|
|
|
14 |
ENV dockerrun yes
|
15 |
CMD ["python3", "-u", "ChuanhuChatbot.py", "2>&1", "|", "tee", "/var/log/application.log"]
|
README.md
CHANGED
@@ -1,13 +1,105 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<h1 align="center">川虎 Chat 🐯 Chuanhu Chat</h1>
|
2 |
+
<div align="center">
|
3 |
+
<a href="https://github.com/GaiZhenBiao/ChuanhuChatGPT">
|
4 |
+
<img src="https://user-images.githubusercontent.com/70903329/227087087-93b37d64-7dc3-4738-a518-c1cf05591c8a.png" alt="Logo" height="156">
|
5 |
+
</a>
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
<h3>为ChatGPT/ChatGLM/LLaMA等多种LLM提供了一个轻快好用的Web图形界面</h3>
|
9 |
+
<p align="center">
|
10 |
+
<a href="https://github.com/GaiZhenbiao/ChuanhuChatGPT/blob/main/LICENSE">
|
11 |
+
<img alt="Tests Passing" src="https://img.shields.io/github/license/GaiZhenbiao/ChuanhuChatGPT" />
|
12 |
+
</a>
|
13 |
+
<a href="https://gradio.app/">
|
14 |
+
<img alt="GitHub Contributors" src="https://img.shields.io/badge/Base-Gradio-fb7d1a?style=flat" />
|
15 |
+
</a>
|
16 |
+
<a href="https://t.me/tkdifferent">
|
17 |
+
<img alt="GitHub pull requests" src="https://img.shields.io/badge/Telegram-Group-blue.svg?logo=telegram" />
|
18 |
+
</a>
|
19 |
+
<p>
|
20 |
+
实时回复 / 无限对话 / 保存对话 / 预设Prompt集 / 联网搜索 / 根据文件回答 <br />
|
21 |
+
渲染LaTeX / 渲染表格 / 代码高亮 / 自动亮暗色切换 / 自适应界面 / “小而美”的体验 <br />
|
22 |
+
自定义api-Host / 多参数可调 / 多API Key均衡负载 / 多用户显示 / 适配GPT-4 / 支持本地部署LLM
|
23 |
+
</p>
|
24 |
+
<a href="https://www.bilibili.com/video/BV1mo4y1r7eE"><strong>视频教程</strong></a>
|
25 |
+
·
|
26 |
+
<a href="https://www.bilibili.com/video/BV1184y1w7aP"><strong>2.0介绍视频</strong></a>
|
27 |
+
||
|
28 |
+
<a href="https://huggingface.co/spaces/JohnSmith9982/ChuanhuChatGPT"><strong>在线体验</strong></a>
|
29 |
+
·
|
30 |
+
<a href="https://huggingface.co/login?next=%2Fspaces%2FJohnSmith9982%2FChuanhuChatGPT%3Fduplicate%3Dtrue"><strong>一键部署</strong></a>
|
31 |
+
</p>
|
32 |
+
<p align="center">
|
33 |
+
<img alt="Animation Demo" src="https://user-images.githubusercontent.com/51039745/226255695-6b17ff1f-ea8d-464f-b69b-a7b6b68fffe8.gif" />
|
34 |
+
</p>
|
35 |
+
</p>
|
36 |
+
</div>
|
37 |
+
|
38 |
+
## 目录
|
39 |
+
|[使用技巧](#使用技巧)|[安装方式](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程)|[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)| [给作者买可乐🥤](#捐款) |
|
40 |
+
| ---- | ---- | ---- | --- |
|
41 |
+
|
42 |
+
## 使用技巧
|
43 |
+
|
44 |
+
- 使用System Prompt可以很有效地设定前提条件。
|
45 |
+
- 使用Prompt模板功能时,选择Prompt模板集合文件,然后从下拉菜单中选择想要的prompt。
|
46 |
+
- 如果回答不满意,可以使用`重新生成`按钮再试一次
|
47 |
+
- 对于长对话,可以使用`优化Tokens`按钮减少Tokens占用。
|
48 |
+
- 输入框支持换行,按`shift enter`即可。
|
49 |
+
- 可以在输入框按上下箭头在输入历史之间切换
|
50 |
+
- 部署到服务器:将程序最后一句改成`demo.launch(server_name="0.0.0.0", server_port=<你的端口号>)`。
|
51 |
+
- 获取公共链接:将程序最后一句改成`demo.launch(share=True)`。注意程序必须在运行,才能通过公共链接访问。
|
52 |
+
- 在Hugging Face上使用:建议在右上角 **复制Space** 再使用,这样App反应可能会快一点。
|
53 |
+
|
54 |
+
|
55 |
+
## 安装方式、使用方式
|
56 |
+
|
57 |
+
请查看[本项目的wiki页面](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用教程)。
|
58 |
+
|
59 |
+
## 疑难杂症解决
|
60 |
+
|
61 |
+
在遇到各种问题查阅相关信息前,您可以先尝试手动拉取本项目的最新更改并更新 gradio,然后重试。步骤为:
|
62 |
+
|
63 |
+
1. 点击网页上的 `Download ZIP` 下载最新代码,或
|
64 |
+
```shell
|
65 |
+
git pull https://github.com/GaiZhenbiao/ChuanhuChatGPT.git main -f
|
66 |
+
```
|
67 |
+
2. 尝试再次安装依赖(可能本项目引入了新的依赖)
|
68 |
+
```
|
69 |
+
pip install -r requirements.txt
|
70 |
+
```
|
71 |
+
3. 更新gradio
|
72 |
+
```
|
73 |
+
pip install gradio --upgrade --force-reinstall
|
74 |
+
```
|
75 |
+
|
76 |
+
很多时候,这样就可以解决问题。
|
77 |
+
|
78 |
+
如果问题仍然存在,请查阅该页面:[常见问题](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/常见问题)
|
79 |
+
|
80 |
+
该页面列出了**几乎所有**您可能遇到的各种问题,包括如何配置代理,以及遇到问题后您该采取的措施,**请务必认真阅读**。
|
81 |
+
|
82 |
+
## 了解更多
|
83 |
+
|
84 |
+
若需了解更多信息,请查看我们的 [wiki](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki):
|
85 |
+
|
86 |
+
- [想要做出贡献?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/贡献指南)
|
87 |
+
- [项目更新情况?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/更新日志)
|
88 |
+
- [二次开发许可?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可)
|
89 |
+
- [如何引用项目?](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/使用许可#如何引用该项目)
|
90 |
+
|
91 |
+
## Starchart
|
92 |
+
|
93 |
+
[![Star History Chart](https://api.star-history.com/svg?repos=GaiZhenbiao/ChuanhuChatGPT&type=Date)](https://star-history.com/#GaiZhenbiao/ChuanhuChatGPT&Date)
|
94 |
+
|
95 |
+
## Contributors
|
96 |
+
|
97 |
+
<a href="https://github.com/GaiZhenbiao/ChuanhuChatGPT/graphs/contributors">
|
98 |
+
<img src="https://contrib.rocks/image?repo=GaiZhenbiao/ChuanhuChatGPT" />
|
99 |
+
</a>
|
100 |
+
|
101 |
+
## 捐款
|
102 |
+
|
103 |
+
🐯如果觉得这个软件对你有所帮助,欢迎请作者喝可乐、喝咖啡~
|
104 |
+
|
105 |
+
<img width="250" alt="image" src="https://user-images.githubusercontent.com/51039745/226920291-e8ec0b0a-400f-4c20-ac13-dafac0c3aeeb.JPG">
|
assets/custom.css
CHANGED
@@ -3,14 +3,18 @@
|
|
3 |
--chatbot-color-dark: #121111;
|
4 |
}
|
5 |
|
|
|
|
|
|
|
|
|
6 |
/* 覆盖gradio的页脚信息QAQ */
|
7 |
footer {
|
8 |
display: none !important;
|
9 |
}
|
10 |
-
#footer{
|
11 |
text-align: center;
|
12 |
}
|
13 |
-
#footer div{
|
14 |
display: inline-block;
|
15 |
}
|
16 |
#footer .versions{
|
@@ -18,16 +22,34 @@ footer {
|
|
18 |
opacity: 0.85;
|
19 |
}
|
20 |
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
#user_info {
|
23 |
white-space: nowrap;
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
}
|
27 |
#user_info p {
|
28 |
-
|
29 |
-
font-
|
30 |
-
|
|
|
|
|
|
|
31 |
}
|
32 |
|
33 |
/* status_display */
|
@@ -43,14 +65,18 @@ footer {
|
|
43 |
color: var(--body-text-color-subdued);
|
44 |
}
|
45 |
|
46 |
-
#
|
47 |
transition: all 0.6s;
|
48 |
}
|
|
|
|
|
|
|
49 |
|
50 |
/* usage_display */
|
51 |
-
|
52 |
position: relative;
|
53 |
margin: 0;
|
|
|
54 |
box-shadow: var(--block-shadow);
|
55 |
border-width: var(--block-border-width);
|
56 |
border-color: var(--block-border-color);
|
@@ -62,7 +88,6 @@ footer {
|
|
62 |
}
|
63 |
#usage_display p, #usage_display span {
|
64 |
margin: 0;
|
65 |
-
padding: .5em 1em;
|
66 |
font-size: .85em;
|
67 |
color: var(--body-text-color-subdued);
|
68 |
}
|
@@ -74,7 +99,7 @@ footer {
|
|
74 |
overflow: hidden;
|
75 |
}
|
76 |
.progress {
|
77 |
-
background-color: var(--block-title-background-fill)
|
78 |
height: 100%;
|
79 |
border-radius: 10px;
|
80 |
text-align: right;
|
@@ -88,38 +113,107 @@ footer {
|
|
88 |
padding-right: 10px;
|
89 |
line-height: 20px;
|
90 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
/* list */
|
92 |
ol:not(.options), ul:not(.options) {
|
93 |
padding-inline-start: 2em !important;
|
94 |
}
|
95 |
|
96 |
-
/*
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
#chuanhu_chatbot {
|
99 |
-
|
100 |
-
color: #000000 !important;
|
101 |
}
|
102 |
-
|
103 |
-
|
104 |
-
}
|
105 |
-
[data-testid = "user"] {
|
106 |
-
background-color: #95EC69 !important;
|
107 |
}
|
108 |
}
|
109 |
-
/*
|
110 |
-
@media (
|
111 |
#chuanhu_chatbot {
|
112 |
-
|
113 |
-
color: #FFFFFF !important;
|
114 |
}
|
115 |
-
|
116 |
-
|
117 |
}
|
118 |
-
[data-testid = "
|
119 |
-
|
120 |
}
|
121 |
-
|
122 |
-
|
123 |
}
|
124 |
}
|
125 |
/* 对话气泡 */
|
|
|
3 |
--chatbot-color-dark: #121111;
|
4 |
}
|
5 |
|
6 |
+
#app_title {
|
7 |
+
margin-top: 6px;
|
8 |
+
white-space: nowrap;
|
9 |
+
}
|
10 |
/* 覆盖gradio的页脚信息QAQ */
|
11 |
footer {
|
12 |
display: none !important;
|
13 |
}
|
14 |
+
#footer {
|
15 |
text-align: center;
|
16 |
}
|
17 |
+
#footer div {
|
18 |
display: inline-block;
|
19 |
}
|
20 |
#footer .versions{
|
|
|
22 |
opacity: 0.85;
|
23 |
}
|
24 |
|
25 |
+
#float_display {
|
26 |
+
position: absolute;
|
27 |
+
max-height: 30px;
|
28 |
+
}
|
29 |
+
/* user_info */
|
30 |
#user_info {
|
31 |
white-space: nowrap;
|
32 |
+
position: absolute; left: 8em; top: .2em;
|
33 |
+
z-index: var(--layer-2);
|
34 |
+
box-shadow: var(--block-shadow);
|
35 |
+
border: none; border-radius: var(--block-label-radius);
|
36 |
+
background: var(--color-accent);
|
37 |
+
padding: var(--block-label-padding);
|
38 |
+
font-size: var(--block-label-text-size); line-height: var(--line-sm);
|
39 |
+
width: auto; min-height: 30px!important;
|
40 |
+
opacity: 1;
|
41 |
+
transition: opacity 0.3s ease-in-out;
|
42 |
+
}
|
43 |
+
#user_info .wrap {
|
44 |
+
opacity: 0;
|
45 |
}
|
46 |
#user_info p {
|
47 |
+
color: white;
|
48 |
+
font-weight: var(--block-label-text-weight);
|
49 |
+
}
|
50 |
+
#user_info.hideK {
|
51 |
+
opacity: 0;
|
52 |
+
transition: opacity 1s ease-in-out;
|
53 |
}
|
54 |
|
55 |
/* status_display */
|
|
|
65 |
color: var(--body-text-color-subdued);
|
66 |
}
|
67 |
|
68 |
+
#status_display {
|
69 |
transition: all 0.6s;
|
70 |
}
|
71 |
+
#chuanhu_chatbot {
|
72 |
+
transition: height 0.3s ease;
|
73 |
+
}
|
74 |
|
75 |
/* usage_display */
|
76 |
+
.insert_block {
|
77 |
position: relative;
|
78 |
margin: 0;
|
79 |
+
padding: .5em 1em;
|
80 |
box-shadow: var(--block-shadow);
|
81 |
border-width: var(--block-border-width);
|
82 |
border-color: var(--block-border-color);
|
|
|
88 |
}
|
89 |
#usage_display p, #usage_display span {
|
90 |
margin: 0;
|
|
|
91 |
font-size: .85em;
|
92 |
color: var(--body-text-color-subdued);
|
93 |
}
|
|
|
99 |
overflow: hidden;
|
100 |
}
|
101 |
.progress {
|
102 |
+
background-color: var(--block-title-background-fill);
|
103 |
height: 100%;
|
104 |
border-radius: 10px;
|
105 |
text-align: right;
|
|
|
113 |
padding-right: 10px;
|
114 |
line-height: 20px;
|
115 |
}
|
116 |
+
|
117 |
+
.apSwitch {
|
118 |
+
top: 2px;
|
119 |
+
display: inline-block;
|
120 |
+
height: 24px;
|
121 |
+
position: relative;
|
122 |
+
width: 48px;
|
123 |
+
border-radius: 12px;
|
124 |
+
}
|
125 |
+
.apSwitch input {
|
126 |
+
display: none !important;
|
127 |
+
}
|
128 |
+
.apSlider {
|
129 |
+
background-color: var(--block-label-background-fill);
|
130 |
+
bottom: 0;
|
131 |
+
cursor: pointer;
|
132 |
+
left: 0;
|
133 |
+
position: absolute;
|
134 |
+
right: 0;
|
135 |
+
top: 0;
|
136 |
+
transition: .4s;
|
137 |
+
font-size: 18px;
|
138 |
+
border-radius: 12px;
|
139 |
+
}
|
140 |
+
.apSlider::before {
|
141 |
+
bottom: -1.5px;
|
142 |
+
left: 1px;
|
143 |
+
position: absolute;
|
144 |
+
transition: .4s;
|
145 |
+
content: "🌞";
|
146 |
+
}
|
147 |
+
input:checked + .apSlider {
|
148 |
+
background-color: var(--block-label-background-fill);
|
149 |
+
}
|
150 |
+
input:checked + .apSlider::before {
|
151 |
+
transform: translateX(23px);
|
152 |
+
content:"🌚";
|
153 |
+
}
|
154 |
+
|
155 |
+
#submit_btn, #cancel_btn {
|
156 |
+
height: 42px !important;
|
157 |
+
}
|
158 |
+
#submit_btn::before {
|
159 |
+
content: url("data:image/svg+xml, %3Csvg width='21px' height='20px' viewBox='0 0 21 20' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='page' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cg id='send' transform='translate(0.435849, 0.088463)' fill='%23FFFFFF' fill-rule='nonzero'%3E %3Cpath d='M0.579148261,0.0428666046 C0.301105539,-0.0961547561 -0.036517765,0.122307382 0.0032026237,0.420210298 L1.4927172,18.1553639 C1.5125774,18.4334066 1.79062012,18.5922882 2.04880264,18.4929872 L8.24518329,15.8913017 L11.6412765,19.7441794 C11.8597387,19.9825018 12.2370824,19.8832008 12.3165231,19.5852979 L13.9450591,13.4882182 L19.7839562,11.0255541 C20.0619989,10.8865327 20.0818591,10.4694687 19.7839562,10.3105871 L0.579148261,0.0428666046 Z M11.6138902,17.0883151 L9.85385903,14.7195502 L0.718169621,0.618812241 L12.69945,12.9346347 L11.6138902,17.0883151 Z' id='shape'%3E%3C/path%3E %3C/g%3E %3C/g%3E %3C/svg%3E");
|
160 |
+
height: 21px;
|
161 |
+
}
|
162 |
+
#cancel_btn::before {
|
163 |
+
content: url("data:image/svg+xml,%3Csvg width='21px' height='21px' viewBox='0 0 21 21' version='1.1' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E %3Cg id='pg' stroke='none' stroke-width='1' fill='none' fill-rule='evenodd'%3E %3Cpath d='M10.2072007,20.088463 C11.5727865,20.088463 12.8594566,19.8259823 14.067211,19.3010209 C15.2749653,18.7760595 16.3386126,18.0538087 17.2581528,17.1342685 C18.177693,16.2147282 18.8982283,15.1527965 19.4197586,13.9484733 C19.9412889,12.7441501 20.202054,11.4557644 20.202054,10.0833163 C20.202054,8.71773046 19.9395733,7.43106036 19.4146119,6.22330603 C18.8896505,5.01555169 18.1673997,3.95018885 17.2478595,3.0272175 C16.3283192,2.10424615 15.2646719,1.3837109 14.0569176,0.865611739 C12.8491633,0.34751258 11.5624932,0.088463 10.1969073,0.088463 C8.83132146,0.088463 7.54636692,0.34751258 6.34204371,0.865611739 C5.1377205,1.3837109 4.07407321,2.10424615 3.15110186,3.0272175 C2.22813051,3.95018885 1.5058797,5.01555169 0.984349419,6.22330603 C0.46281914,7.43106036 0.202054,8.71773046 0.202054,10.0833163 C0.202054,11.4557644 0.4645347,12.7441501 0.9894961,13.9484733 C1.5144575,15.1527965 2.23670831,16.2147282 3.15624854,17.1342685 C4.07578877,18.0538087 5.1377205,18.7760595 6.34204371,19.3010209 C7.54636692,19.8259823 8.83475258,20.088463 10.2072007,20.088463 Z M10.2072007,18.2562448 C9.07493099,18.2562448 8.01471483,18.0452309 7.0265522,17.6232031 C6.03838956,17.2011753 5.17031614,16.6161693 4.42233192,15.8681851 C3.6743477,15.1202009 3.09105726,14.2521274 2.67246059,13.2639648 C2.25386392,12.2758022 2.04456558,11.215586 2.04456558,10.0833163 C2.04456558,8.95104663 2.25386392,7.89083047 2.67246059,6.90266784 C3.09105726,5.9145052 3.6743477,5.04643178 4.42233192,4.29844756 C5.17031614,3.55046334 6.036674,2.9671729 7.02140552,2.54857623 C8.00613703,2.12997956 9.06463763,1.92068122 10.1969073,1.92068122 C11.329177,1.92068122 12.3911087,2.12997956 13.3827025,2.54857623 C14.3742962,2.9671729 15.2440852,3.55046334 15.9920694,4.29844756 C16.7400537,5.04643178 17.3233441,5.9145052 17.7419408,6.90266784 C18.1605374,7.89083047 18.3698358,8.95104663 18.3698358,10.0833163 C18.3698358,11.215586 18.1605374,12.2758022 17.7419408,13.2639648 C17.3233441,14.2521274 16.7400537,15.1202009 15.9920694,15.8681851 C15.2440852,16.6161693 14.3760118,17.2011753 13.3878492,17.6232031 C12.3996865,18.0452309 11.3394704,18.2562448 10.2072007,18.2562448 Z M7.65444721,13.6242324 L12.7496608,13.6242324 C13.0584616,13.6242324 13.3003556,13.5384544 13.4753427,13.3668984 C13.6503299,13.1953424 13.7378234,12.9585951 13.7378234,12.6566565 L13.7378234,7.49968276 C13.7378234,7.19774418 13.6503299,6.96099688 13.4753427,6.78944087 C13.3003556,6.61788486 13.0584616,6.53210685 12.7496608,6.53210685 L7.65444721,6.53210685 C7.33878414,6.53210685 7.09345904,6.61788486 6.91847191,6.78944087 C6.74348478,6.96099688 6.65599121,7.19774418 6.65599121,7.49968276 L6.65599121,12.6566565 C6.65599121,12.9585951 6.74348478,13.1953424 6.91847191,13.3668984 C7.09345904,13.5384544 7.33878414,13.6242324 7.65444721,13.6242324 Z' id='shape' fill='%23FF3B30' fill-rule='nonzero'%3E%3C/path%3E %3C/g%3E %3C/svg%3E");
|
164 |
+
height: 21px;
|
165 |
+
}
|
166 |
/* list */
|
167 |
ol:not(.options), ul:not(.options) {
|
168 |
padding-inline-start: 2em !important;
|
169 |
}
|
170 |
|
171 |
+
/* 亮色(默认) */
|
172 |
+
#chuanhu_chatbot {
|
173 |
+
background-color: var(--chatbot-color-light) !important;
|
174 |
+
color: #000000 !important;
|
175 |
+
}
|
176 |
+
[data-testid = "bot"] {
|
177 |
+
background-color: #FFFFFF !important;
|
178 |
+
}
|
179 |
+
[data-testid = "user"] {
|
180 |
+
background-color: #95EC69 !important;
|
181 |
+
}
|
182 |
+
/* 暗色 */
|
183 |
+
.dark #chuanhu_chatbot {
|
184 |
+
background-color: var(--chatbot-color-dark) !important;
|
185 |
+
color: #FFFFFF !important;
|
186 |
+
}
|
187 |
+
.dark [data-testid = "bot"] {
|
188 |
+
background-color: #2C2C2C !important;
|
189 |
+
}
|
190 |
+
.dark [data-testid = "user"] {
|
191 |
+
background-color: #26B561 !important;
|
192 |
+
}
|
193 |
+
|
194 |
+
/* 屏幕宽度大于等于500px的设备 */
|
195 |
+
/* update on 2023.4.8: 高度的细致调整已写入JavaScript */
|
196 |
+
@media screen and (min-width: 500px) {
|
197 |
#chuanhu_chatbot {
|
198 |
+
height: calc(100vh - 200px);
|
|
|
199 |
}
|
200 |
+
#chuanhu_chatbot .wrap {
|
201 |
+
max-height: calc(100vh - 200px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
|
|
|
|
|
|
|
202 |
}
|
203 |
}
|
204 |
+
/* 屏幕宽度小于500px的设备 */
|
205 |
+
@media screen and (max-width: 499px) {
|
206 |
#chuanhu_chatbot {
|
207 |
+
height: calc(100vh - 140px);
|
|
|
208 |
}
|
209 |
+
#chuanhu_chatbot .wrap {
|
210 |
+
max-height: calc(100vh - 140px - var(--line-sm)*1rem - 2*var(--block-label-margin) );
|
211 |
}
|
212 |
+
[data-testid = "bot"] {
|
213 |
+
max-width: 98% !important;
|
214 |
}
|
215 |
+
#app_title h1{
|
216 |
+
letter-spacing: -1px; font-size: 22px;
|
217 |
}
|
218 |
}
|
219 |
/* 对话气泡 */
|
assets/custom.js
CHANGED
@@ -1,70 +1,224 @@
|
|
|
|
1 |
// custom javascript here
|
|
|
2 |
const MAX_HISTORY_LENGTH = 32;
|
3 |
|
4 |
var key_down_history = [];
|
5 |
var currentIndex = -1;
|
6 |
var user_input_ta;
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
var ga = document.getElementsByTagName("gradio-app");
|
9 |
var targetNode = ga[0];
|
10 |
-
var
|
|
|
|
|
|
|
11 |
for (var i = 0; i < mutations.length; i++) {
|
12 |
-
if (mutations[i].addedNodes.length) {
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
}
|
59 |
}
|
60 |
-
}
|
61 |
-
break;
|
62 |
}
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
}
|
65 |
-
|
66 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
// custom javascript here
|
3 |
+
|
4 |
const MAX_HISTORY_LENGTH = 32;
|
5 |
|
6 |
var key_down_history = [];
|
7 |
var currentIndex = -1;
|
8 |
var user_input_ta;
|
9 |
|
10 |
+
var gradioContainer = null;
|
11 |
+
var user_input_ta = null;
|
12 |
+
var user_input_tb = null;
|
13 |
+
var userInfoDiv = null;
|
14 |
+
var appTitleDiv = null;
|
15 |
+
var chatbot = null;
|
16 |
+
var apSwitch = null;
|
17 |
+
|
18 |
var ga = document.getElementsByTagName("gradio-app");
|
19 |
var targetNode = ga[0];
|
20 |
+
var isInIframe = (window.self !== window.top);
|
21 |
+
|
22 |
+
// gradio 页面加载好了么??? 我能动你的元素了么??
|
23 |
+
function gradioLoaded(mutations) {
|
24 |
for (var i = 0; i < mutations.length; i++) {
|
25 |
+
if (mutations[i].addedNodes.length) {
|
26 |
+
gradioContainer = document.querySelector(".gradio-container");
|
27 |
+
user_input_tb = document.getElementById('user_input_tb');
|
28 |
+
userInfoDiv = document.getElementById("user_info");
|
29 |
+
appTitleDiv = document.getElementById("app_title");
|
30 |
+
chatbot = document.querySelector('#chuanhu_chatbot');
|
31 |
+
apSwitch = document.querySelector('.apSwitch input[type="checkbox"]');
|
32 |
+
|
33 |
+
if (gradioContainer && apSwitch) { // gradioCainter 加载出来了没?
|
34 |
+
adjustDarkMode();
|
35 |
+
}
|
36 |
+
if (user_input_tb) { // user_input_tb 加载出来了没?
|
37 |
+
selectHistory();
|
38 |
+
}
|
39 |
+
if (userInfoDiv && appTitleDiv) { // userInfoDiv 和 appTitleDiv 加载出来了没?
|
40 |
+
setTimeout(showOrHideUserInfo(), 2000);
|
41 |
+
}
|
42 |
+
if (chatbot) { // chatbot 加载出来了没?
|
43 |
+
setChatbotHeight()
|
44 |
+
}
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
function selectHistory() {
|
50 |
+
user_input_ta = user_input_tb.querySelector("textarea");
|
51 |
+
if (user_input_ta) {
|
52 |
+
observer.disconnect(); // 停止监听
|
53 |
+
// 在 textarea 上监听 keydown 事件
|
54 |
+
user_input_ta.addEventListener("keydown", function (event) {
|
55 |
+
var value = user_input_ta.value.trim();
|
56 |
+
// 判断按下的是否为方向键
|
57 |
+
if (event.code === 'ArrowUp' || event.code === 'ArrowDown') {
|
58 |
+
// ���果按下的是方向键,且输入框中有内容,且历史记录中没有该内容,则不执行操作
|
59 |
+
if (value && key_down_history.indexOf(value) === -1)
|
60 |
+
return;
|
61 |
+
// 对于需要响应的动作,阻止默认行为。
|
62 |
+
event.preventDefault();
|
63 |
+
var length = key_down_history.length;
|
64 |
+
if (length === 0) {
|
65 |
+
currentIndex = -1; // 如果历史记录为空,直接将当前选中的记录重置
|
66 |
+
return;
|
67 |
+
}
|
68 |
+
if (currentIndex === -1) {
|
69 |
+
currentIndex = length;
|
70 |
+
}
|
71 |
+
if (event.code === 'ArrowUp' && currentIndex > 0) {
|
72 |
+
currentIndex--;
|
73 |
+
user_input_ta.value = key_down_history[currentIndex];
|
74 |
+
} else if (event.code === 'ArrowDown' && currentIndex < length - 1) {
|
75 |
+
currentIndex++;
|
76 |
+
user_input_ta.value = key_down_history[currentIndex];
|
77 |
+
}
|
78 |
+
user_input_ta.selectionStart = user_input_ta.value.length;
|
79 |
+
user_input_ta.selectionEnd = user_input_ta.value.length;
|
80 |
+
const input_event = new InputEvent("input", { bubbles: true, cancelable: true });
|
81 |
+
user_input_ta.dispatchEvent(input_event);
|
82 |
+
} else if (event.code === "Enter") {
|
83 |
+
if (value) {
|
84 |
+
currentIndex = -1;
|
85 |
+
if (key_down_history.indexOf(value) === -1) {
|
86 |
+
key_down_history.push(value);
|
87 |
+
if (key_down_history.length > MAX_HISTORY_LENGTH) {
|
88 |
+
key_down_history.shift();
|
89 |
}
|
90 |
}
|
91 |
+
}
|
|
|
92 |
}
|
93 |
+
});
|
94 |
+
}
|
95 |
+
}
|
96 |
+
|
97 |
+
function toggleUserInfoVisibility(shouldHide) {
|
98 |
+
if (userInfoDiv) {
|
99 |
+
if (shouldHide) {
|
100 |
+
userInfoDiv.classList.add("hideK");
|
101 |
+
} else {
|
102 |
+
userInfoDiv.classList.remove("hideK");
|
103 |
}
|
104 |
+
}
|
105 |
+
}
|
106 |
+
function showOrHideUserInfo() {
|
107 |
+
var sendBtn = document.getElementById("submit_btn");
|
108 |
+
|
109 |
+
// Bind mouse/touch events to show/hide user info
|
110 |
+
appTitleDiv.addEventListener("mouseenter", function () {
|
111 |
+
toggleUserInfoVisibility(false);
|
112 |
+
});
|
113 |
+
userInfoDiv.addEventListener("mouseenter", function () {
|
114 |
+
toggleUserInfoVisibility(false);
|
115 |
+
});
|
116 |
+
sendBtn.addEventListener("mouseenter", function () {
|
117 |
+
toggleUserInfoVisibility(false);
|
118 |
+
});
|
119 |
+
|
120 |
+
appTitleDiv.addEventListener("mouseleave", function () {
|
121 |
+
toggleUserInfoVisibility(true);
|
122 |
+
});
|
123 |
+
userInfoDiv.addEventListener("mouseleave", function () {
|
124 |
+
toggleUserInfoVisibility(true);
|
125 |
+
});
|
126 |
+
sendBtn.addEventListener("mouseleave", function () {
|
127 |
+
toggleUserInfoVisibility(true);
|
128 |
+
});
|
129 |
+
|
130 |
+
appTitleDiv.ontouchstart = function () {
|
131 |
+
toggleUserInfoVisibility(false);
|
132 |
+
};
|
133 |
+
userInfoDiv.ontouchstart = function () {
|
134 |
+
toggleUserInfoVisibility(false);
|
135 |
+
};
|
136 |
+
sendBtn.ontouchstart = function () {
|
137 |
+
toggleUserInfoVisibility(false);
|
138 |
+
};
|
139 |
+
|
140 |
+
appTitleDiv.ontouchend = function () {
|
141 |
+
setTimeout(function () {
|
142 |
+
toggleUserInfoVisibility(true);
|
143 |
+
}, 3000);
|
144 |
+
};
|
145 |
+
userInfoDiv.ontouchend = function () {
|
146 |
+
setTimeout(function () {
|
147 |
+
toggleUserInfoVisibility(true);
|
148 |
+
}, 3000);
|
149 |
+
};
|
150 |
+
sendBtn.ontouchend = function () {
|
151 |
+
setTimeout(function () {
|
152 |
+
toggleUserInfoVisibility(true);
|
153 |
+
}, 3000); // Delay 1 second to hide user info
|
154 |
+
};
|
155 |
+
|
156 |
+
// Hide user info after 2 second
|
157 |
+
setTimeout(function () {
|
158 |
+
toggleUserInfoVisibility(true);
|
159 |
+
}, 2000);
|
160 |
+
}
|
161 |
|
162 |
+
function toggleDarkMode(isEnabled) {
|
163 |
+
if (isEnabled) {
|
164 |
+
gradioContainer.classList.add("dark");
|
165 |
+
document.body.style.setProperty("background-color", "var(--neutral-950)", "important");
|
166 |
+
} else {
|
167 |
+
gradioContainer.classList.remove("dark");
|
168 |
+
document.body.style.backgroundColor = "";
|
169 |
+
}
|
170 |
+
}
|
171 |
+
function adjustDarkMode() {
|
172 |
+
const darkModeQuery = window.matchMedia("(prefers-color-scheme: dark)");
|
173 |
|
174 |
+
// 根据当前颜色模式设置初始状态
|
175 |
+
apSwitch.checked = darkModeQuery.matches;
|
176 |
+
toggleDarkMode(darkModeQuery.matches);
|
177 |
+
// 监听颜色模式变化
|
178 |
+
darkModeQuery.addEventListener("change", (e) => {
|
179 |
+
apSwitch.checked = e.matches;
|
180 |
+
toggleDarkMode(e.matches);
|
181 |
+
});
|
182 |
+
// apSwitch = document.querySelector('.apSwitch input[type="checkbox"]');
|
183 |
+
apSwitch.addEventListener("change", (e) => {
|
184 |
+
toggleDarkMode(e.target.checked);
|
185 |
+
});
|
186 |
+
}
|
187 |
+
|
188 |
+
function setChatbotHeight() {
|
189 |
+
const screenWidth = window.innerWidth;
|
190 |
+
const statusDisplay = document.querySelector('#status_display');
|
191 |
+
const statusDisplayHeight = statusDisplay ? statusDisplay.offsetHeight : 0;
|
192 |
+
const wrap = chatbot.querySelector('.wrap');
|
193 |
+
const vh = window.innerHeight * 0.01;
|
194 |
+
document.documentElement.style.setProperty('--vh', `${vh}px`);
|
195 |
+
if (isInIframe) {
|
196 |
+
chatbot.style.height = `700px`;
|
197 |
+
wrap.style.maxHeight = `calc(700px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`
|
198 |
+
} else {
|
199 |
+
if (screenWidth <= 320) {
|
200 |
+
chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 150}px)`;
|
201 |
+
wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 150}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
|
202 |
+
} else if (screenWidth <= 499) {
|
203 |
+
chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 100}px)`;
|
204 |
+
wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 100}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
|
205 |
+
} else {
|
206 |
+
chatbot.style.height = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 160}px)`;
|
207 |
+
wrap.style.maxHeight = `calc(var(--vh, 1vh) * 100 - ${statusDisplayHeight + 160}px - var(--line-sm) * 1rem - 2 * var(--block-label-margin))`;
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
// 监视页面内部 DOM 变动
|
213 |
+
var observer = new MutationObserver(function (mutations) {
|
214 |
+
gradioLoaded(mutations);
|
215 |
+
});
|
216 |
+
observer.observe(targetNode, { childList: true, subtree: true });
|
217 |
+
|
218 |
+
// 监视页面变化
|
219 |
+
window.addEventListener("DOMContentLoaded", function () {
|
220 |
+
isInIframe = (window.self !== window.top);
|
221 |
+
});
|
222 |
+
window.addEventListener('resize', setChatbotHeight);
|
223 |
+
window.addEventListener('scroll', setChatbotHeight);
|
224 |
+
window.matchMedia("(prefers-color-scheme: dark)").addEventListener("change", adjustDarkMode);
|
config_example.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// 你的OpenAI API Key,一般必填,
|
3 |
+
// 若缺省填为 "openai_api_key": "" 则必须再在图形界面中填入API Key
|
4 |
+
"openai_api_key": "",
|
5 |
+
// 如果使用代理,请取消注释下面的两行,并替换代理URL
|
6 |
+
// "https_proxy": "http://127.0.0.1:1079",
|
7 |
+
// "http_proxy": "http://127.0.0.1:1079",
|
8 |
+
"users": [], // 用户列表,[[用户名1, 密码1], [用户名2, 密码2], ...]
|
9 |
+
"local_embedding": false, //是否在本地编制索引
|
10 |
+
"default_model": "gpt-3.5-turbo", // 默认模型
|
11 |
+
"advance_docs": {
|
12 |
+
"pdf": {
|
13 |
+
// 是否认为PDF是双栏的
|
14 |
+
"two_column": false,
|
15 |
+
// 是否使用OCR识别PDF中的公式
|
16 |
+
"formula_ocr": true
|
17 |
+
}
|
18 |
+
},
|
19 |
+
// 是否多个API Key轮换使用
|
20 |
+
"multi_api_key": false,
|
21 |
+
"api_key_list": [
|
22 |
+
"sk-xxxxxxxxxxxxxxxxxxxxxxxx1",
|
23 |
+
"sk-xxxxxxxxxxxxxxxxxxxxxxxx2",
|
24 |
+
"sk-xxxxxxxxxxxxxxxxxxxxxxxx3"
|
25 |
+
],
|
26 |
+
// 如果使用自定义端口、自定义ip,请取消注释并替换对应内容
|
27 |
+
// "server_name": "0.0.0.0",
|
28 |
+
// "server_port": 7860,
|
29 |
+
// 如果要share到gradio,设置为true
|
30 |
+
// "share": false,
|
31 |
+
}
|
configs/ds_config_chatbot.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": false
|
4 |
+
},
|
5 |
+
"bf16": {
|
6 |
+
"enabled": true
|
7 |
+
},
|
8 |
+
"comms_logger": {
|
9 |
+
"enabled": false,
|
10 |
+
"verbose": false,
|
11 |
+
"prof_all": false,
|
12 |
+
"debug": false
|
13 |
+
},
|
14 |
+
"steps_per_print": 20000000000000000,
|
15 |
+
"train_micro_batch_size_per_gpu": 1,
|
16 |
+
"wall_clock_breakdown": false
|
17 |
+
}
|
modules/__init__.py
ADDED
File without changes
|
modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (172 Bytes). View file
|
|
modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (154 Bytes). View file
|
|
modules/__pycache__/base_model.cpython-311.pyc
ADDED
Binary file (26.7 kB). View file
|
|
modules/__pycache__/base_model.cpython-39.pyc
ADDED
Binary file (15.8 kB). View file
|
|
modules/__pycache__/config.cpython-311.pyc
ADDED
Binary file (7.87 kB). View file
|
|
modules/__pycache__/config.cpython-39.pyc
CHANGED
Binary files a/modules/__pycache__/config.cpython-39.pyc and b/modules/__pycache__/config.cpython-39.pyc differ
|
|
modules/__pycache__/llama_func.cpython-311.pyc
ADDED
Binary file (9.28 kB). View file
|
|
modules/__pycache__/models.cpython-311.pyc
ADDED
Binary file (30.6 kB). View file
|
|
modules/base_model.py
ADDED
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import TYPE_CHECKING, List
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
+
import commentjson as cjson
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import requests
|
10 |
+
import urllib3
|
11 |
+
import traceback
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
import colorama
|
15 |
+
from duckduckgo_search import ddg
|
16 |
+
import asyncio
|
17 |
+
import aiohttp
|
18 |
+
from enum import Enum
|
19 |
+
|
20 |
+
from .presets import *
|
21 |
+
from .llama_func import *
|
22 |
+
from .utils import *
|
23 |
+
from . import shared
|
24 |
+
from .config import retrieve_proxy
|
25 |
+
|
26 |
+
|
27 |
+
class ModelType(Enum):
|
28 |
+
Unknown = -1
|
29 |
+
OpenAI = 0
|
30 |
+
ChatGLM = 1
|
31 |
+
LLaMA = 2
|
32 |
+
XMBot = 3
|
33 |
+
|
34 |
+
@classmethod
|
35 |
+
def get_type(cls, model_name: str):
|
36 |
+
model_type = None
|
37 |
+
model_name_lower = model_name.lower()
|
38 |
+
if "gpt" in model_name_lower:
|
39 |
+
model_type = ModelType.OpenAI
|
40 |
+
elif "chatglm" in model_name_lower:
|
41 |
+
model_type = ModelType.ChatGLM
|
42 |
+
elif "llama" in model_name_lower or "alpaca" in model_name_lower:
|
43 |
+
model_type = ModelType.LLaMA
|
44 |
+
elif "xmbot" in model_name_lower:
|
45 |
+
model_type = ModelType.XMBot
|
46 |
+
else:
|
47 |
+
model_type = ModelType.Unknown
|
48 |
+
return model_type
|
49 |
+
|
50 |
+
|
51 |
+
class BaseLLMModel:
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
model_name,
|
55 |
+
system_prompt="",
|
56 |
+
temperature=1.0,
|
57 |
+
top_p=1.0,
|
58 |
+
n_choices=1,
|
59 |
+
stop=None,
|
60 |
+
max_generation_token=None,
|
61 |
+
presence_penalty=0,
|
62 |
+
frequency_penalty=0,
|
63 |
+
logit_bias=None,
|
64 |
+
user="",
|
65 |
+
) -> None:
|
66 |
+
self.history = []
|
67 |
+
self.all_token_counts = []
|
68 |
+
self.model_name = model_name
|
69 |
+
self.model_type = ModelType.get_type(model_name)
|
70 |
+
try:
|
71 |
+
self.token_upper_limit = MODEL_TOKEN_LIMIT[model_name]
|
72 |
+
except KeyError:
|
73 |
+
self.token_upper_limit = DEFAULT_TOKEN_LIMIT
|
74 |
+
self.interrupted = False
|
75 |
+
self.system_prompt = system_prompt
|
76 |
+
self.api_key = None
|
77 |
+
self.need_api_key = False
|
78 |
+
self.single_turn = False
|
79 |
+
|
80 |
+
self.temperature = temperature
|
81 |
+
self.top_p = top_p
|
82 |
+
self.n_choices = n_choices
|
83 |
+
self.stop_sequence = stop
|
84 |
+
self.max_generation_token = None
|
85 |
+
self.presence_penalty = presence_penalty
|
86 |
+
self.frequency_penalty = frequency_penalty
|
87 |
+
self.logit_bias = logit_bias
|
88 |
+
self.user_identifier = user
|
89 |
+
|
90 |
+
def get_answer_stream_iter(self):
|
91 |
+
"""stream predict, need to be implemented
|
92 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
93 |
+
should return a generator, each time give the next word (str) in the answer
|
94 |
+
"""
|
95 |
+
logging.warning("stream predict not implemented, using at once predict instead")
|
96 |
+
response, _ = self.get_answer_at_once()
|
97 |
+
yield response
|
98 |
+
|
99 |
+
def get_answer_at_once(self):
|
100 |
+
"""predict at once, need to be implemented
|
101 |
+
conversations are stored in self.history, with the most recent question, in OpenAI format
|
102 |
+
Should return:
|
103 |
+
the answer (str)
|
104 |
+
total token count (int)
|
105 |
+
"""
|
106 |
+
logging.warning("at once predict not implemented, using stream predict instead")
|
107 |
+
response_iter = self.get_answer_stream_iter()
|
108 |
+
count = 0
|
109 |
+
for response in response_iter:
|
110 |
+
count += 1
|
111 |
+
return response, sum(self.all_token_counts) + count
|
112 |
+
|
113 |
+
def billing_info(self):
|
114 |
+
"""get billing infomation, inplement if needed"""
|
115 |
+
logging.warning("billing info not implemented, using default")
|
116 |
+
return BILLING_NOT_APPLICABLE_MSG
|
117 |
+
|
118 |
+
def count_token(self, user_input):
|
119 |
+
"""get token count from input, implement if needed"""
|
120 |
+
logging.warning("token count not implemented, using default")
|
121 |
+
return len(user_input)
|
122 |
+
|
123 |
+
def stream_next_chatbot(self, inputs, chatbot, fake_input=None, display_append=""):
|
124 |
+
def get_return_value():
|
125 |
+
return chatbot, status_text
|
126 |
+
|
127 |
+
status_text = "开始实时传输回答……"
|
128 |
+
if fake_input:
|
129 |
+
chatbot.append((fake_input, ""))
|
130 |
+
else:
|
131 |
+
chatbot.append((inputs, ""))
|
132 |
+
|
133 |
+
user_token_count = self.count_token(inputs)
|
134 |
+
self.all_token_counts.append(user_token_count)
|
135 |
+
logging.debug(f"输入token计数: {user_token_count}")
|
136 |
+
|
137 |
+
stream_iter = self.get_answer_stream_iter()
|
138 |
+
|
139 |
+
for partial_text in stream_iter:
|
140 |
+
chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
|
141 |
+
self.all_token_counts[-1] += 1
|
142 |
+
status_text = self.token_message()
|
143 |
+
yield get_return_value()
|
144 |
+
if self.interrupted:
|
145 |
+
self.recover()
|
146 |
+
break
|
147 |
+
self.history.append(construct_assistant(partial_text))
|
148 |
+
|
149 |
+
def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
|
150 |
+
if fake_input:
|
151 |
+
chatbot.append((fake_input, ""))
|
152 |
+
else:
|
153 |
+
chatbot.append((inputs, ""))
|
154 |
+
if fake_input is not None:
|
155 |
+
user_token_count = self.count_token(fake_input)
|
156 |
+
else:
|
157 |
+
user_token_count = self.count_token(inputs)
|
158 |
+
self.all_token_counts.append(user_token_count)
|
159 |
+
ai_reply, total_token_count = self.get_answer_at_once()
|
160 |
+
self.history.append(construct_assistant(ai_reply))
|
161 |
+
if fake_input is not None:
|
162 |
+
self.history[-2] = construct_user(fake_input)
|
163 |
+
chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
|
164 |
+
if fake_input is not None:
|
165 |
+
self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
|
166 |
+
else:
|
167 |
+
self.all_token_counts[-1] = total_token_count - sum(self.all_token_counts)
|
168 |
+
status_text = self.token_message()
|
169 |
+
return chatbot, status_text
|
170 |
+
|
171 |
+
def handle_file_upload(self, files, chatbot):
|
172 |
+
"""if the model accepts multi modal input, implement this function"""
|
173 |
+
status = gr.Markdown.update()
|
174 |
+
if files:
|
175 |
+
construct_index(self.api_key, file_src=files)
|
176 |
+
status = "索引构建完成"
|
177 |
+
return gr.Files.update(), chatbot, status
|
178 |
+
|
179 |
+
def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
|
180 |
+
fake_inputs = None
|
181 |
+
display_append = []
|
182 |
+
limited_context = False
|
183 |
+
fake_inputs = real_inputs
|
184 |
+
if files:
|
185 |
+
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
186 |
+
from llama_index.indices.query.schema import QueryBundle
|
187 |
+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
188 |
+
from langchain.chat_models import ChatOpenAI
|
189 |
+
from llama_index import (
|
190 |
+
GPTSimpleVectorIndex,
|
191 |
+
ServiceContext,
|
192 |
+
LangchainEmbedding,
|
193 |
+
OpenAIEmbedding,
|
194 |
+
)
|
195 |
+
limited_context = True
|
196 |
+
msg = "加载索引中……"
|
197 |
+
logging.info(msg)
|
198 |
+
# yield chatbot + [(inputs, "")], msg
|
199 |
+
index = construct_index(self.api_key, file_src=files)
|
200 |
+
assert index is not None, "获取索引失败"
|
201 |
+
msg = "索引获取成功,生成回答中……"
|
202 |
+
logging.info(msg)
|
203 |
+
if local_embedding or self.model_type != ModelType.OpenAI:
|
204 |
+
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
|
205 |
+
else:
|
206 |
+
embed_model = OpenAIEmbedding()
|
207 |
+
# yield chatbot + [(inputs, "")], msg
|
208 |
+
with retrieve_proxy():
|
209 |
+
prompt_helper = PromptHelper(
|
210 |
+
max_input_size=4096,
|
211 |
+
num_output=5,
|
212 |
+
max_chunk_overlap=20,
|
213 |
+
chunk_size_limit=600,
|
214 |
+
)
|
215 |
+
from llama_index import ServiceContext
|
216 |
+
|
217 |
+
service_context = ServiceContext.from_defaults(
|
218 |
+
prompt_helper=prompt_helper, embed_model=embed_model
|
219 |
+
)
|
220 |
+
query_object = GPTVectorStoreIndexQuery(
|
221 |
+
index.index_struct,
|
222 |
+
service_context=service_context,
|
223 |
+
similarity_top_k=5,
|
224 |
+
vector_store=index._vector_store,
|
225 |
+
docstore=index._docstore,
|
226 |
+
)
|
227 |
+
query_bundle = QueryBundle(real_inputs)
|
228 |
+
nodes = query_object.retrieve(query_bundle)
|
229 |
+
reference_results = [n.node.text for n in nodes]
|
230 |
+
reference_results = add_source_numbers(reference_results, use_source=False)
|
231 |
+
display_append = add_details(reference_results)
|
232 |
+
display_append = "\n\n" + "".join(display_append)
|
233 |
+
real_inputs = (
|
234 |
+
replace_today(PROMPT_TEMPLATE)
|
235 |
+
.replace("{query_str}", real_inputs)
|
236 |
+
.replace("{context_str}", "\n\n".join(reference_results))
|
237 |
+
.replace("{reply_language}", reply_language)
|
238 |
+
)
|
239 |
+
elif use_websearch:
|
240 |
+
limited_context = True
|
241 |
+
search_results = ddg(real_inputs, max_results=5)
|
242 |
+
reference_results = []
|
243 |
+
for idx, result in enumerate(search_results):
|
244 |
+
logging.debug(f"搜索结果{idx + 1}:{result}")
|
245 |
+
domain_name = urllib3.util.parse_url(result["href"]).host
|
246 |
+
reference_results.append([result["body"], result["href"]])
|
247 |
+
display_append.append(
|
248 |
+
f"{idx+1}. [{domain_name}]({result['href']})\n"
|
249 |
+
)
|
250 |
+
reference_results = add_source_numbers(reference_results)
|
251 |
+
display_append = "\n\n" + "".join(display_append)
|
252 |
+
real_inputs = (
|
253 |
+
replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
|
254 |
+
.replace("{query}", real_inputs)
|
255 |
+
.replace("{web_results}", "\n\n".join(reference_results))
|
256 |
+
.replace("{reply_language}", reply_language)
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
display_append = ""
|
260 |
+
return limited_context, fake_inputs, display_append, real_inputs, chatbot
|
261 |
+
|
262 |
+
def predict(
|
263 |
+
self,
|
264 |
+
inputs,
|
265 |
+
chatbot,
|
266 |
+
stream=False,
|
267 |
+
use_websearch=False,
|
268 |
+
files=None,
|
269 |
+
reply_language="中文",
|
270 |
+
should_check_token_count=True,
|
271 |
+
): # repetition_penalty, top_k
|
272 |
+
|
273 |
+
status_text = "开始生成回答……"
|
274 |
+
logging.info(
|
275 |
+
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
|
276 |
+
)
|
277 |
+
if should_check_token_count:
|
278 |
+
yield chatbot + [(inputs, "")], status_text
|
279 |
+
if reply_language == "跟随问题语言(不稳定)":
|
280 |
+
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
281 |
+
|
282 |
+
limited_context, fake_inputs, display_append, inputs, chatbot = self.prepare_inputs(real_inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language, chatbot=chatbot)
|
283 |
+
yield chatbot + [(fake_inputs, "")], status_text
|
284 |
+
|
285 |
+
if (
|
286 |
+
self.need_api_key and
|
287 |
+
self.api_key is None
|
288 |
+
and not shared.state.multi_api_key
|
289 |
+
):
|
290 |
+
status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
|
291 |
+
logging.info(status_text)
|
292 |
+
chatbot.append((inputs, ""))
|
293 |
+
if len(self.history) == 0:
|
294 |
+
self.history.append(construct_user(inputs))
|
295 |
+
self.history.append("")
|
296 |
+
self.all_token_counts.append(0)
|
297 |
+
else:
|
298 |
+
self.history[-2] = construct_user(inputs)
|
299 |
+
yield chatbot + [(inputs, "")], status_text
|
300 |
+
return
|
301 |
+
elif len(inputs.strip()) == 0:
|
302 |
+
status_text = STANDARD_ERROR_MSG + NO_INPUT_MSG
|
303 |
+
logging.info(status_text)
|
304 |
+
yield chatbot + [(inputs, "")], status_text
|
305 |
+
return
|
306 |
+
|
307 |
+
if self.single_turn:
|
308 |
+
self.history = []
|
309 |
+
self.all_token_counts = []
|
310 |
+
self.history.append(construct_user(inputs))
|
311 |
+
|
312 |
+
try:
|
313 |
+
if stream:
|
314 |
+
logging.debug("使用流式传输")
|
315 |
+
iter = self.stream_next_chatbot(
|
316 |
+
inputs,
|
317 |
+
chatbot,
|
318 |
+
fake_input=fake_inputs,
|
319 |
+
display_append=display_append,
|
320 |
+
)
|
321 |
+
for chatbot, status_text in iter:
|
322 |
+
yield chatbot, status_text
|
323 |
+
else:
|
324 |
+
logging.debug("不使用流式传输")
|
325 |
+
chatbot, status_text = self.next_chatbot_at_once(
|
326 |
+
inputs,
|
327 |
+
chatbot,
|
328 |
+
fake_input=fake_inputs,
|
329 |
+
display_append=display_append,
|
330 |
+
)
|
331 |
+
yield chatbot, status_text
|
332 |
+
except Exception as e:
|
333 |
+
traceback.print_exc()
|
334 |
+
status_text = STANDARD_ERROR_MSG + str(e)
|
335 |
+
yield chatbot, status_text
|
336 |
+
|
337 |
+
if len(self.history) > 1 and self.history[-1]["content"] != inputs:
|
338 |
+
logging.info(
|
339 |
+
"回答为:"
|
340 |
+
+ colorama.Fore.BLUE
|
341 |
+
+ f"{self.history[-1]['content']}"
|
342 |
+
+ colorama.Style.RESET_ALL
|
343 |
+
)
|
344 |
+
|
345 |
+
if limited_context:
|
346 |
+
# self.history = self.history[-4:]
|
347 |
+
# self.all_token_counts = self.all_token_counts[-2:]
|
348 |
+
self.history = []
|
349 |
+
self.all_token_counts = []
|
350 |
+
|
351 |
+
max_token = self.token_upper_limit - TOKEN_OFFSET
|
352 |
+
|
353 |
+
if sum(self.all_token_counts) > max_token and should_check_token_count:
|
354 |
+
count = 0
|
355 |
+
while (
|
356 |
+
sum(self.all_token_counts)
|
357 |
+
> self.token_upper_limit * REDUCE_TOKEN_FACTOR
|
358 |
+
and sum(self.all_token_counts) > 0
|
359 |
+
):
|
360 |
+
count += 1
|
361 |
+
del self.all_token_counts[0]
|
362 |
+
del self.history[:2]
|
363 |
+
logging.info(status_text)
|
364 |
+
status_text = f"为了防止token超限,模型忘记了早期的 {count} 轮对话"
|
365 |
+
yield chatbot, status_text
|
366 |
+
|
367 |
+
def retry(
|
368 |
+
self,
|
369 |
+
chatbot,
|
370 |
+
stream=False,
|
371 |
+
use_websearch=False,
|
372 |
+
files=None,
|
373 |
+
reply_language="中文",
|
374 |
+
):
|
375 |
+
logging.debug("重试中……")
|
376 |
+
if len(self.history) == 0:
|
377 |
+
yield chatbot, f"{STANDARD_ERROR_MSG}上下文是空的"
|
378 |
+
return
|
379 |
+
|
380 |
+
inputs = self.history[-2]["content"]
|
381 |
+
del self.history[-2:]
|
382 |
+
self.all_token_counts.pop()
|
383 |
+
iter = self.predict(
|
384 |
+
inputs,
|
385 |
+
chatbot,
|
386 |
+
stream=stream,
|
387 |
+
use_websearch=use_websearch,
|
388 |
+
files=files,
|
389 |
+
reply_language=reply_language,
|
390 |
+
)
|
391 |
+
for x in iter:
|
392 |
+
yield x
|
393 |
+
logging.debug("重试完毕")
|
394 |
+
|
395 |
+
# def reduce_token_size(self, chatbot):
|
396 |
+
# logging.info("开始减少token数量……")
|
397 |
+
# chatbot, status_text = self.next_chatbot_at_once(
|
398 |
+
# summarize_prompt,
|
399 |
+
# chatbot
|
400 |
+
# )
|
401 |
+
# max_token_count = self.token_upper_limit * REDUCE_TOKEN_FACTOR
|
402 |
+
# num_chat = find_n(self.all_token_counts, max_token_count)
|
403 |
+
# logging.info(f"previous_token_count: {self.all_token_counts}, keeping {num_chat} chats")
|
404 |
+
# chatbot = chatbot[:-1]
|
405 |
+
# self.history = self.history[-2*num_chat:] if num_chat > 0 else []
|
406 |
+
# self.all_token_counts = self.all_token_counts[-num_chat:] if num_chat > 0 else []
|
407 |
+
# msg = f"保留了最近{num_chat}轮对话"
|
408 |
+
# logging.info(msg)
|
409 |
+
# logging.info("减少token数量完毕")
|
410 |
+
# return chatbot, msg + "," + self.token_message(self.all_token_counts if len(self.all_token_counts) > 0 else [0])
|
411 |
+
|
412 |
+
def interrupt(self):
|
413 |
+
self.interrupted = True
|
414 |
+
|
415 |
+
def recover(self):
|
416 |
+
self.interrupted = False
|
417 |
+
|
418 |
+
def set_token_upper_limit(self, new_upper_limit):
|
419 |
+
self.token_upper_limit = new_upper_limit
|
420 |
+
print(f"token上限设置为{new_upper_limit}")
|
421 |
+
|
422 |
+
def set_temperature(self, new_temperature):
|
423 |
+
self.temperature = new_temperature
|
424 |
+
|
425 |
+
def set_top_p(self, new_top_p):
|
426 |
+
self.top_p = new_top_p
|
427 |
+
|
428 |
+
def set_n_choices(self, new_n_choices):
|
429 |
+
self.n_choices = new_n_choices
|
430 |
+
|
431 |
+
def set_stop_sequence(self, new_stop_sequence: str):
|
432 |
+
new_stop_sequence = new_stop_sequence.split(",")
|
433 |
+
self.stop_sequence = new_stop_sequence
|
434 |
+
|
435 |
+
def set_max_tokens(self, new_max_tokens):
|
436 |
+
self.max_generation_token = new_max_tokens
|
437 |
+
|
438 |
+
def set_presence_penalty(self, new_presence_penalty):
|
439 |
+
self.presence_penalty = new_presence_penalty
|
440 |
+
|
441 |
+
def set_frequency_penalty(self, new_frequency_penalty):
|
442 |
+
self.frequency_penalty = new_frequency_penalty
|
443 |
+
|
444 |
+
def set_logit_bias(self, logit_bias):
|
445 |
+
logit_bias = logit_bias.split()
|
446 |
+
bias_map = {}
|
447 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
448 |
+
for line in logit_bias:
|
449 |
+
word, bias_amount = line.split(":")
|
450 |
+
if word:
|
451 |
+
for token in encoding.encode(word):
|
452 |
+
bias_map[token] = float(bias_amount)
|
453 |
+
self.logit_bias = bias_map
|
454 |
+
|
455 |
+
def set_user_identifier(self, new_user_identifier):
|
456 |
+
self.user_identifier = new_user_identifier
|
457 |
+
|
458 |
+
def set_system_prompt(self, new_system_prompt):
|
459 |
+
self.system_prompt = new_system_prompt
|
460 |
+
|
461 |
+
def set_key(self, new_access_key):
|
462 |
+
self.api_key = new_access_key.strip()
|
463 |
+
msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
|
464 |
+
logging.info(msg)
|
465 |
+
return new_access_key, msg
|
466 |
+
|
467 |
+
def set_single_turn(self, new_single_turn):
|
468 |
+
self.single_turn = new_single_turn
|
469 |
+
|
470 |
+
def reset(self):
|
471 |
+
self.history = []
|
472 |
+
self.all_token_counts = []
|
473 |
+
self.interrupted = False
|
474 |
+
return [], self.token_message([0])
|
475 |
+
|
476 |
+
def delete_first_conversation(self):
|
477 |
+
if self.history:
|
478 |
+
del self.history[:2]
|
479 |
+
del self.all_token_counts[0]
|
480 |
+
return self.token_message()
|
481 |
+
|
482 |
+
def delete_last_conversation(self, chatbot):
|
483 |
+
if len(chatbot) > 0 and STANDARD_ERROR_MSG in chatbot[-1][1]:
|
484 |
+
msg = "由于包含报错信息,只删除chatbot记录"
|
485 |
+
chatbot.pop()
|
486 |
+
return chatbot, self.history
|
487 |
+
if len(self.history) > 0:
|
488 |
+
self.history.pop()
|
489 |
+
self.history.pop()
|
490 |
+
if len(chatbot) > 0:
|
491 |
+
msg = "删除了一组chatbot对话"
|
492 |
+
chatbot.pop()
|
493 |
+
if len(self.all_token_counts) > 0:
|
494 |
+
msg = "删除了一组对话的token计数记录"
|
495 |
+
self.all_token_counts.pop()
|
496 |
+
msg = "删除了一组对话"
|
497 |
+
return chatbot, msg
|
498 |
+
|
499 |
+
def token_message(self, token_lst=None):
|
500 |
+
if token_lst is None:
|
501 |
+
token_lst = self.all_token_counts
|
502 |
+
token_sum = 0
|
503 |
+
for i in range(len(token_lst)):
|
504 |
+
token_sum += sum(token_lst[: i + 1])
|
505 |
+
return f"Token 计数: {sum(token_lst)},本次对话累计消耗了 {token_sum} tokens"
|
506 |
+
|
507 |
+
def save_chat_history(self, filename, chatbot, user_name):
|
508 |
+
if filename == "":
|
509 |
+
return
|
510 |
+
if not filename.endswith(".json"):
|
511 |
+
filename += ".json"
|
512 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
513 |
+
|
514 |
+
def export_markdown(self, filename, chatbot, user_name):
|
515 |
+
if filename == "":
|
516 |
+
return
|
517 |
+
if not filename.endswith(".md"):
|
518 |
+
filename += ".md"
|
519 |
+
return save_file(filename, self.system_prompt, self.history, chatbot, user_name)
|
520 |
+
|
521 |
+
def load_chat_history(self, filename, chatbot, user_name):
|
522 |
+
logging.debug(f"{user_name} 加载对话历史中……")
|
523 |
+
if type(filename) != str:
|
524 |
+
filename = filename.name
|
525 |
+
try:
|
526 |
+
with open(os.path.join(HISTORY_DIR, user_name, filename), "r") as f:
|
527 |
+
json_s = json.load(f)
|
528 |
+
try:
|
529 |
+
if type(json_s["history"][0]) == str:
|
530 |
+
logging.info("历史记录格式为旧版,正在转换……")
|
531 |
+
new_history = []
|
532 |
+
for index, item in enumerate(json_s["history"]):
|
533 |
+
if index % 2 == 0:
|
534 |
+
new_history.append(construct_user(item))
|
535 |
+
else:
|
536 |
+
new_history.append(construct_assistant(item))
|
537 |
+
json_s["history"] = new_history
|
538 |
+
logging.info(new_history)
|
539 |
+
except:
|
540 |
+
# 没有对话历史
|
541 |
+
pass
|
542 |
+
logging.debug(f"{user_name} 加载对话历史完毕")
|
543 |
+
self.history = json_s["history"]
|
544 |
+
return filename, json_s["system"], json_s["chatbot"]
|
545 |
+
except FileNotFoundError:
|
546 |
+
logging.warning(f"{user_name} 没有找到对话历史文件,不执行任何操作")
|
547 |
+
return filename, self.system_prompt, chatbot
|
modules/config.py
CHANGED
@@ -3,9 +3,10 @@ from contextlib import contextmanager
|
|
3 |
import os
|
4 |
import logging
|
5 |
import sys
|
6 |
-
import json
|
7 |
|
8 |
from . import shared
|
|
|
9 |
|
10 |
|
11 |
__all__ = [
|
@@ -18,6 +19,9 @@ __all__ = [
|
|
18 |
"advance_docs",
|
19 |
"update_doc_config",
|
20 |
"multi_api_key",
|
|
|
|
|
|
|
21 |
]
|
22 |
|
23 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
@@ -28,6 +32,30 @@ if os.path.exists("config.json"):
|
|
28 |
else:
|
29 |
config = {}
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
## 处理docker if we are running in Docker
|
32 |
dockerflag = config.get("dockerflag", False)
|
33 |
if os.environ.get("dockerrun") == "yes":
|
@@ -54,35 +82,6 @@ api_host = os.environ.get("api_host", config.get("api_host", ""))
|
|
54 |
if api_host:
|
55 |
shared.state.set_api_host(api_host)
|
56 |
|
57 |
-
if dockerflag:
|
58 |
-
if my_api_key == "empty":
|
59 |
-
logging.error("Please give a api key!")
|
60 |
-
sys.exit(1)
|
61 |
-
# auth
|
62 |
-
username = os.environ.get("USERNAME")
|
63 |
-
password = os.environ.get("PASSWORD")
|
64 |
-
if not (isinstance(username, type(None)) or isinstance(password, type(None))):
|
65 |
-
auth_list.append((os.environ.get("USERNAME"), os.environ.get("PASSWORD")))
|
66 |
-
authflag = True
|
67 |
-
else:
|
68 |
-
if (
|
69 |
-
not my_api_key
|
70 |
-
and os.path.exists("api_key.txt")
|
71 |
-
and os.path.getsize("api_key.txt")
|
72 |
-
):
|
73 |
-
with open("api_key.txt", "r") as f:
|
74 |
-
my_api_key = f.read().strip()
|
75 |
-
if os.path.exists("auth.json"):
|
76 |
-
authflag = True
|
77 |
-
with open("auth.json", "r", encoding='utf-8') as f:
|
78 |
-
auth = json.load(f)
|
79 |
-
for _ in auth:
|
80 |
-
if auth[_]["username"] and auth[_]["password"]:
|
81 |
-
auth_list.append((auth[_]["username"], auth[_]["password"]))
|
82 |
-
else:
|
83 |
-
logging.error("请检查auth.json文件中的用户名和密码!")
|
84 |
-
sys.exit(1)
|
85 |
-
|
86 |
@contextmanager
|
87 |
def retrieve_openai_api(api_key = None):
|
88 |
old_api_key = os.environ.get("OPENAI_API_KEY", "")
|
@@ -111,6 +110,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
|
|
111 |
os.environ["HTTP_PROXY"] = ""
|
112 |
os.environ["HTTPS_PROXY"] = ""
|
113 |
|
|
|
|
|
114 |
@contextmanager
|
115 |
def retrieve_proxy(proxy=None):
|
116 |
"""
|
@@ -137,9 +138,29 @@ advance_docs = defaultdict(lambda: defaultdict(dict))
|
|
137 |
advance_docs.update(config.get("advance_docs", {}))
|
138 |
def update_doc_config(two_column_pdf):
|
139 |
global advance_docs
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
else:
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
-
|
|
|
3 |
import os
|
4 |
import logging
|
5 |
import sys
|
6 |
+
import commentjson as json
|
7 |
|
8 |
from . import shared
|
9 |
+
from . import presets
|
10 |
|
11 |
|
12 |
__all__ = [
|
|
|
19 |
"advance_docs",
|
20 |
"update_doc_config",
|
21 |
"multi_api_key",
|
22 |
+
"server_name",
|
23 |
+
"server_port",
|
24 |
+
"share",
|
25 |
]
|
26 |
|
27 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
|
|
32 |
else:
|
33 |
config = {}
|
34 |
|
35 |
+
if os.path.exists("api_key.txt"):
|
36 |
+
logging.info("检测到api_key.txt文件,正在进行迁移...")
|
37 |
+
with open("api_key.txt", "r") as f:
|
38 |
+
config["openai_api_key"] = f.read().strip()
|
39 |
+
os.rename("api_key.txt", "api_key(deprecated).txt")
|
40 |
+
with open("config.json", "w", encoding='utf-8') as f:
|
41 |
+
json.dump(config, f, indent=4)
|
42 |
+
|
43 |
+
if os.path.exists("auth.json"):
|
44 |
+
logging.info("检测到auth.json文件,正在进行迁移...")
|
45 |
+
auth_list = []
|
46 |
+
with open("auth.json", "r", encoding='utf-8') as f:
|
47 |
+
auth = json.load(f)
|
48 |
+
for _ in auth:
|
49 |
+
if auth[_]["username"] and auth[_]["password"]:
|
50 |
+
auth_list.append((auth[_]["username"], auth[_]["password"]))
|
51 |
+
else:
|
52 |
+
logging.error("请检查auth.json文件中的用户名和密码!")
|
53 |
+
sys.exit(1)
|
54 |
+
config["users"] = auth_list
|
55 |
+
os.rename("auth.json", "auth(deprecated).json")
|
56 |
+
with open("config.json", "w", encoding='utf-8') as f:
|
57 |
+
json.dump(config, f, indent=4)
|
58 |
+
|
59 |
## 处理docker if we are running in Docker
|
60 |
dockerflag = config.get("dockerflag", False)
|
61 |
if os.environ.get("dockerrun") == "yes":
|
|
|
82 |
if api_host:
|
83 |
shared.state.set_api_host(api_host)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
@contextmanager
|
86 |
def retrieve_openai_api(api_key = None):
|
87 |
old_api_key = os.environ.get("OPENAI_API_KEY", "")
|
|
|
110 |
os.environ["HTTP_PROXY"] = ""
|
111 |
os.environ["HTTPS_PROXY"] = ""
|
112 |
|
113 |
+
local_embedding = config.get("local_embedding", False) # 是否使用本地embedding
|
114 |
+
|
115 |
@contextmanager
|
116 |
def retrieve_proxy(proxy=None):
|
117 |
"""
|
|
|
138 |
advance_docs.update(config.get("advance_docs", {}))
|
139 |
def update_doc_config(two_column_pdf):
|
140 |
global advance_docs
|
141 |
+
advance_docs["pdf"]["two_column"] = two_column_pdf
|
142 |
+
|
143 |
+
logging.info(f"更新后的文件参数为:{advance_docs}")
|
144 |
+
|
145 |
+
## 处理gradio.launch参数
|
146 |
+
server_name = config.get("server_name", None)
|
147 |
+
server_port = config.get("server_port", None)
|
148 |
+
if server_name is None:
|
149 |
+
if dockerflag:
|
150 |
+
server_name = "0.0.0.0"
|
151 |
else:
|
152 |
+
server_name = "127.0.0.1"
|
153 |
+
if server_port is None:
|
154 |
+
if dockerflag:
|
155 |
+
server_port = 7860
|
156 |
+
|
157 |
+
assert server_port is None or type(server_port) == int, "要求port设置为int类型"
|
158 |
+
|
159 |
+
# 设置默认model
|
160 |
+
default_model = config.get("default_model", "")
|
161 |
+
try:
|
162 |
+
presets.DEFAULT_MODEL = presets.MODELS.index(default_model)
|
163 |
+
except ValueError:
|
164 |
+
pass
|
165 |
|
166 |
+
share = config.get("share", False)
|
modules/llama_func.py
CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
|
|
15 |
|
16 |
from modules.presets import *
|
17 |
from modules.utils import *
|
|
|
|
|
18 |
|
19 |
def get_index_name(file_src):
|
20 |
file_paths = [x.name for x in file_src]
|
@@ -28,6 +30,7 @@ def get_index_name(file_src):
|
|
28 |
|
29 |
return md5_hash.hexdigest()
|
30 |
|
|
|
31 |
def block_split(text):
|
32 |
blocks = []
|
33 |
while len(text) > 0:
|
@@ -35,6 +38,7 @@ def block_split(text):
|
|
35 |
text = text[1000:]
|
36 |
return blocks
|
37 |
|
|
|
38 |
def get_documents(file_src):
|
39 |
documents = []
|
40 |
logging.debug("Loading documents...")
|
@@ -44,37 +48,45 @@ def get_documents(file_src):
|
|
44 |
filename = os.path.basename(filepath)
|
45 |
file_type = os.path.splitext(filepath)[1]
|
46 |
logging.info(f"loading file: {filename}")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
text = add_space(text_raw)
|
79 |
# text = block_split(text)
|
80 |
# documents += text
|
@@ -84,27 +96,36 @@ def get_documents(file_src):
|
|
84 |
|
85 |
|
86 |
def construct_index(
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
):
|
96 |
from langchain.chat_models import ChatOpenAI
|
97 |
-
from
|
|
|
98 |
|
99 |
-
|
|
|
|
|
|
|
|
|
100 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
101 |
embedding_limit = None if embedding_limit == 0 else embedding_limit
|
102 |
separator = " " if separator == "" else separator
|
103 |
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
106 |
)
|
107 |
-
prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
|
108 |
index_name = get_index_name(file_src)
|
109 |
if os.path.exists(f"./index/{index_name}.json"):
|
110 |
logging.info("找到了缓存的索引文件,加载中……")
|
@@ -112,11 +133,19 @@ def construct_index(
|
|
112 |
else:
|
113 |
try:
|
114 |
documents = get_documents(file_src)
|
|
|
|
|
|
|
|
|
115 |
logging.info("构建索引中……")
|
116 |
with retrieve_proxy():
|
117 |
-
service_context = ServiceContext.from_defaults(
|
|
|
|
|
|
|
|
|
118 |
index = GPTSimpleVectorIndex.from_documents(
|
119 |
-
documents,
|
120 |
)
|
121 |
logging.debug("索引构建完成!")
|
122 |
os.makedirs("./index", exist_ok=True)
|
|
|
15 |
|
16 |
from modules.presets import *
|
17 |
from modules.utils import *
|
18 |
+
from modules.config import local_embedding
|
19 |
+
|
20 |
|
21 |
def get_index_name(file_src):
|
22 |
file_paths = [x.name for x in file_src]
|
|
|
30 |
|
31 |
return md5_hash.hexdigest()
|
32 |
|
33 |
+
|
34 |
def block_split(text):
|
35 |
blocks = []
|
36 |
while len(text) > 0:
|
|
|
38 |
text = text[1000:]
|
39 |
return blocks
|
40 |
|
41 |
+
|
42 |
def get_documents(file_src):
|
43 |
documents = []
|
44 |
logging.debug("Loading documents...")
|
|
|
48 |
filename = os.path.basename(filepath)
|
49 |
file_type = os.path.splitext(filepath)[1]
|
50 |
logging.info(f"loading file: {filename}")
|
51 |
+
try:
|
52 |
+
if file_type == ".pdf":
|
53 |
+
logging.debug("Loading PDF...")
|
54 |
+
try:
|
55 |
+
from modules.pdf_func import parse_pdf
|
56 |
+
from modules.config import advance_docs
|
57 |
+
|
58 |
+
two_column = advance_docs["pdf"].get("two_column", False)
|
59 |
+
pdftext = parse_pdf(filepath, two_column).text
|
60 |
+
except:
|
61 |
+
pdftext = ""
|
62 |
+
with open(filepath, "rb") as pdfFileObj:
|
63 |
+
pdfReader = PyPDF2.PdfReader(pdfFileObj)
|
64 |
+
for page in tqdm(pdfReader.pages):
|
65 |
+
pdftext += page.extract_text()
|
66 |
+
text_raw = pdftext
|
67 |
+
elif file_type == ".docx":
|
68 |
+
logging.debug("Loading Word...")
|
69 |
+
DocxReader = download_loader("DocxReader")
|
70 |
+
loader = DocxReader()
|
71 |
+
text_raw = loader.load_data(file=filepath)[0].text
|
72 |
+
elif file_type == ".epub":
|
73 |
+
logging.debug("Loading EPUB...")
|
74 |
+
EpubReader = download_loader("EpubReader")
|
75 |
+
loader = EpubReader()
|
76 |
+
text_raw = loader.load_data(file=filepath)[0].text
|
77 |
+
elif file_type == ".xlsx":
|
78 |
+
logging.debug("Loading Excel...")
|
79 |
+
text_list = excel_to_string(filepath)
|
80 |
+
for elem in text_list:
|
81 |
+
documents.append(Document(elem))
|
82 |
+
continue
|
83 |
+
else:
|
84 |
+
logging.debug("Loading text file...")
|
85 |
+
with open(filepath, "r", encoding="utf-8") as f:
|
86 |
+
text_raw = f.read()
|
87 |
+
except Exception as e:
|
88 |
+
logging.error(f"Error loading file: {filename}")
|
89 |
+
pass
|
90 |
text = add_space(text_raw)
|
91 |
# text = block_split(text)
|
92 |
# documents += text
|
|
|
96 |
|
97 |
|
98 |
def construct_index(
|
99 |
+
api_key,
|
100 |
+
file_src,
|
101 |
+
max_input_size=4096,
|
102 |
+
num_outputs=5,
|
103 |
+
max_chunk_overlap=20,
|
104 |
+
chunk_size_limit=600,
|
105 |
+
embedding_limit=None,
|
106 |
+
separator=" ",
|
107 |
):
|
108 |
from langchain.chat_models import ChatOpenAI
|
109 |
+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
110 |
+
from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
|
111 |
|
112 |
+
if api_key:
|
113 |
+
os.environ["OPENAI_API_KEY"] = api_key
|
114 |
+
else:
|
115 |
+
# 由于一个依赖的愚蠢的设计,这里必须要有一个API KEY
|
116 |
+
os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
|
117 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
118 |
embedding_limit = None if embedding_limit == 0 else embedding_limit
|
119 |
separator = " " if separator == "" else separator
|
120 |
|
121 |
+
prompt_helper = PromptHelper(
|
122 |
+
max_input_size=max_input_size,
|
123 |
+
num_output=num_outputs,
|
124 |
+
max_chunk_overlap=max_chunk_overlap,
|
125 |
+
embedding_limit=embedding_limit,
|
126 |
+
chunk_size_limit=600,
|
127 |
+
separator=separator,
|
128 |
)
|
|
|
129 |
index_name = get_index_name(file_src)
|
130 |
if os.path.exists(f"./index/{index_name}.json"):
|
131 |
logging.info("找到了缓存的索引文件,加载中……")
|
|
|
133 |
else:
|
134 |
try:
|
135 |
documents = get_documents(file_src)
|
136 |
+
if local_embedding:
|
137 |
+
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
|
138 |
+
else:
|
139 |
+
embed_model = OpenAIEmbedding()
|
140 |
logging.info("构建索引中……")
|
141 |
with retrieve_proxy():
|
142 |
+
service_context = ServiceContext.from_defaults(
|
143 |
+
prompt_helper=prompt_helper,
|
144 |
+
chunk_size_limit=chunk_size_limit,
|
145 |
+
embed_model=embed_model,
|
146 |
+
)
|
147 |
index = GPTSimpleVectorIndex.from_documents(
|
148 |
+
documents, service_context=service_context
|
149 |
)
|
150 |
logging.debug("索引构建完成!")
|
151 |
os.makedirs("./index", exist_ok=True)
|
modules/models.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import TYPE_CHECKING, List
|
3 |
+
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
+
import commentjson as cjson
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import requests
|
10 |
+
import urllib3
|
11 |
+
import platform
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
import colorama
|
15 |
+
from duckduckgo_search import ddg
|
16 |
+
import asyncio
|
17 |
+
import aiohttp
|
18 |
+
from enum import Enum
|
19 |
+
import uuid
|
20 |
+
|
21 |
+
from .presets import *
|
22 |
+
from .llama_func import *
|
23 |
+
from .utils import *
|
24 |
+
from . import shared
|
25 |
+
from .config import retrieve_proxy
|
26 |
+
from modules import config
|
27 |
+
from .base_model import BaseLLMModel, ModelType
|
28 |
+
|
29 |
+
|
30 |
+
class OpenAIClient(BaseLLMModel):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
model_name,
|
34 |
+
api_key,
|
35 |
+
system_prompt=INITIAL_SYSTEM_PROMPT,
|
36 |
+
temperature=1.0,
|
37 |
+
top_p=1.0,
|
38 |
+
) -> None:
|
39 |
+
super().__init__(
|
40 |
+
model_name=model_name,
|
41 |
+
temperature=temperature,
|
42 |
+
top_p=top_p,
|
43 |
+
system_prompt=system_prompt,
|
44 |
+
)
|
45 |
+
self.api_key = api_key
|
46 |
+
self.need_api_key = True
|
47 |
+
self._refresh_header()
|
48 |
+
|
49 |
+
def get_answer_stream_iter(self):
|
50 |
+
response = self._get_response(stream=True)
|
51 |
+
if response is not None:
|
52 |
+
iter = self._decode_chat_response(response)
|
53 |
+
partial_text = ""
|
54 |
+
for i in iter:
|
55 |
+
partial_text += i
|
56 |
+
yield partial_text
|
57 |
+
else:
|
58 |
+
yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
|
59 |
+
|
60 |
+
def get_answer_at_once(self):
|
61 |
+
response = self._get_response()
|
62 |
+
response = json.loads(response.text)
|
63 |
+
content = response["choices"][0]["message"]["content"]
|
64 |
+
total_token_count = response["usage"]["total_tokens"]
|
65 |
+
return content, total_token_count
|
66 |
+
|
67 |
+
def count_token(self, user_input):
|
68 |
+
input_token_count = count_token(construct_user(user_input))
|
69 |
+
if self.system_prompt is not None and len(self.all_token_counts) == 0:
|
70 |
+
system_prompt_token_count = count_token(
|
71 |
+
construct_system(self.system_prompt)
|
72 |
+
)
|
73 |
+
return input_token_count + system_prompt_token_count
|
74 |
+
return input_token_count
|
75 |
+
|
76 |
+
def billing_info(self):
|
77 |
+
try:
|
78 |
+
curr_time = datetime.datetime.now()
|
79 |
+
last_day_of_month = get_last_day_of_month(
|
80 |
+
curr_time).strftime("%Y-%m-%d")
|
81 |
+
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
|
82 |
+
usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
|
83 |
+
try:
|
84 |
+
usage_data = self._get_billing_data(usage_url)
|
85 |
+
except Exception as e:
|
86 |
+
logging.error(f"获取API使用情况失败:" + str(e))
|
87 |
+
return f"**获取API使用情况失败**"
|
88 |
+
rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
|
89 |
+
return f"**本月使用金额** \u3000 ${rounded_usage}"
|
90 |
+
except requests.exceptions.ConnectTimeout:
|
91 |
+
status_text = (
|
92 |
+
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
93 |
+
)
|
94 |
+
return status_text
|
95 |
+
except requests.exceptions.ReadTimeout:
|
96 |
+
status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
97 |
+
return status_text
|
98 |
+
except Exception as e:
|
99 |
+
logging.error(f"获取API使用情况失败:" + str(e))
|
100 |
+
return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
|
101 |
+
|
102 |
+
def set_token_upper_limit(self, new_upper_limit):
|
103 |
+
pass
|
104 |
+
|
105 |
+
def set_key(self, new_access_key):
|
106 |
+
self.api_key = new_access_key.strip()
|
107 |
+
self._refresh_header()
|
108 |
+
msg = f"API密钥更改为了{hide_middle_chars(self.api_key)}"
|
109 |
+
logging.info(msg)
|
110 |
+
return msg
|
111 |
+
|
112 |
+
@shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
|
113 |
+
def _get_response(self, stream=False):
|
114 |
+
openai_api_key = self.api_key
|
115 |
+
system_prompt = self.system_prompt
|
116 |
+
history = self.history
|
117 |
+
logging.debug(colorama.Fore.YELLOW +
|
118 |
+
f"{history}" + colorama.Fore.RESET)
|
119 |
+
headers = {
|
120 |
+
"Content-Type": "application/json",
|
121 |
+
"Authorization": f"Bearer {openai_api_key}",
|
122 |
+
}
|
123 |
+
|
124 |
+
if system_prompt is not None:
|
125 |
+
history = [construct_system(system_prompt), *history]
|
126 |
+
|
127 |
+
payload = {
|
128 |
+
"model": self.model_name,
|
129 |
+
"messages": history,
|
130 |
+
"temperature": self.temperature,
|
131 |
+
"top_p": self.top_p,
|
132 |
+
"n": self.n_choices,
|
133 |
+
"stream": stream,
|
134 |
+
"presence_penalty": self.presence_penalty,
|
135 |
+
"frequency_penalty": self.frequency_penalty,
|
136 |
+
}
|
137 |
+
|
138 |
+
if self.max_generation_token is not None:
|
139 |
+
payload["max_tokens"] = self.max_generation_token
|
140 |
+
if self.stop_sequence is not None:
|
141 |
+
payload["stop"] = self.stop_sequence
|
142 |
+
if self.logit_bias is not None:
|
143 |
+
payload["logit_bias"] = self.logit_bias
|
144 |
+
if self.user_identifier is not None:
|
145 |
+
payload["user"] = self.user_identifier
|
146 |
+
|
147 |
+
if stream:
|
148 |
+
timeout = TIMEOUT_STREAMING
|
149 |
+
else:
|
150 |
+
timeout = TIMEOUT_ALL
|
151 |
+
|
152 |
+
# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
|
153 |
+
if shared.state.completion_url != COMPLETION_URL:
|
154 |
+
logging.info(f"使用自定义API URL: {shared.state.completion_url}")
|
155 |
+
|
156 |
+
with retrieve_proxy():
|
157 |
+
try:
|
158 |
+
response = requests.post(
|
159 |
+
shared.state.completion_url,
|
160 |
+
headers=headers,
|
161 |
+
json=payload,
|
162 |
+
stream=stream,
|
163 |
+
timeout=timeout,
|
164 |
+
)
|
165 |
+
except:
|
166 |
+
return None
|
167 |
+
return response
|
168 |
+
|
169 |
+
def _refresh_header(self):
|
170 |
+
self.headers = {
|
171 |
+
"Content-Type": "application/json",
|
172 |
+
"Authorization": f"Bearer {self.api_key}",
|
173 |
+
}
|
174 |
+
|
175 |
+
def _get_billing_data(self, billing_url):
|
176 |
+
with retrieve_proxy():
|
177 |
+
response = requests.get(
|
178 |
+
billing_url,
|
179 |
+
headers=self.headers,
|
180 |
+
timeout=TIMEOUT_ALL,
|
181 |
+
)
|
182 |
+
|
183 |
+
if response.status_code == 200:
|
184 |
+
data = response.json()
|
185 |
+
return data
|
186 |
+
else:
|
187 |
+
raise Exception(
|
188 |
+
f"API request failed with status code {response.status_code}: {response.text}"
|
189 |
+
)
|
190 |
+
|
191 |
+
def _decode_chat_response(self, response):
|
192 |
+
error_msg = ""
|
193 |
+
for chunk in response.iter_lines():
|
194 |
+
if chunk:
|
195 |
+
chunk = chunk.decode()
|
196 |
+
chunk_length = len(chunk)
|
197 |
+
try:
|
198 |
+
chunk = json.loads(chunk[6:])
|
199 |
+
except json.JSONDecodeError:
|
200 |
+
print(f"JSON解析错误,收到的内容: {chunk}")
|
201 |
+
error_msg += chunk
|
202 |
+
continue
|
203 |
+
if chunk_length > 6 and "delta" in chunk["choices"][0]:
|
204 |
+
if chunk["choices"][0]["finish_reason"] == "stop":
|
205 |
+
break
|
206 |
+
try:
|
207 |
+
yield chunk["choices"][0]["delta"]["content"]
|
208 |
+
except Exception as e:
|
209 |
+
# logging.error(f"Error: {e}")
|
210 |
+
continue
|
211 |
+
if error_msg:
|
212 |
+
raise Exception(error_msg)
|
213 |
+
|
214 |
+
|
215 |
+
class ChatGLM_Client(BaseLLMModel):
|
216 |
+
def __init__(self, model_name) -> None:
|
217 |
+
super().__init__(model_name=model_name)
|
218 |
+
from transformers import AutoTokenizer, AutoModel
|
219 |
+
import torch
|
220 |
+
global CHATGLM_TOKENIZER, CHATGLM_MODEL
|
221 |
+
if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
|
222 |
+
system_name = platform.system()
|
223 |
+
model_path = None
|
224 |
+
if os.path.exists("models"):
|
225 |
+
model_dirs = os.listdir("models")
|
226 |
+
if model_name in model_dirs:
|
227 |
+
model_path = f"models/{model_name}"
|
228 |
+
if model_path is not None:
|
229 |
+
model_source = model_path
|
230 |
+
else:
|
231 |
+
model_source = f"THUDM/{model_name}"
|
232 |
+
CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
|
233 |
+
model_source, trust_remote_code=True
|
234 |
+
)
|
235 |
+
quantified = False
|
236 |
+
if "int4" in model_name:
|
237 |
+
quantified = True
|
238 |
+
if quantified:
|
239 |
+
model = AutoModel.from_pretrained(
|
240 |
+
model_source, trust_remote_code=True
|
241 |
+
).half()
|
242 |
+
else:
|
243 |
+
model = AutoModel.from_pretrained(
|
244 |
+
model_source, trust_remote_code=True
|
245 |
+
).half()
|
246 |
+
if torch.cuda.is_available():
|
247 |
+
# run on CUDA
|
248 |
+
logging.info("CUDA is available, using CUDA")
|
249 |
+
model = model.cuda()
|
250 |
+
# mps加速还存在一些问题,暂时不使用
|
251 |
+
elif system_name == "Darwin" and model_path is not None and not quantified:
|
252 |
+
logging.info("Running on macOS, using MPS")
|
253 |
+
# running on macOS and model already downloaded
|
254 |
+
model = model.to("mps")
|
255 |
+
else:
|
256 |
+
logging.info("GPU is not available, using CPU")
|
257 |
+
model = model.eval()
|
258 |
+
CHATGLM_MODEL = model
|
259 |
+
|
260 |
+
def _get_glm_style_input(self):
|
261 |
+
history = [x["content"] for x in self.history]
|
262 |
+
query = history.pop()
|
263 |
+
logging.debug(colorama.Fore.YELLOW +
|
264 |
+
f"{history}" + colorama.Fore.RESET)
|
265 |
+
assert (
|
266 |
+
len(history) % 2 == 0
|
267 |
+
), f"History should be even length. current history is: {history}"
|
268 |
+
history = [[history[i], history[i + 1]]
|
269 |
+
for i in range(0, len(history), 2)]
|
270 |
+
return history, query
|
271 |
+
|
272 |
+
def get_answer_at_once(self):
|
273 |
+
history, query = self._get_glm_style_input()
|
274 |
+
response, _ = CHATGLM_MODEL.chat(
|
275 |
+
CHATGLM_TOKENIZER, query, history=history)
|
276 |
+
return response, len(response)
|
277 |
+
|
278 |
+
def get_answer_stream_iter(self):
|
279 |
+
history, query = self._get_glm_style_input()
|
280 |
+
for response, history in CHATGLM_MODEL.stream_chat(
|
281 |
+
CHATGLM_TOKENIZER,
|
282 |
+
query,
|
283 |
+
history,
|
284 |
+
max_length=self.token_upper_limit,
|
285 |
+
top_p=self.top_p,
|
286 |
+
temperature=self.temperature,
|
287 |
+
):
|
288 |
+
yield response
|
289 |
+
|
290 |
+
|
291 |
+
class LLaMA_Client(BaseLLMModel):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
model_name,
|
295 |
+
lora_path=None,
|
296 |
+
) -> None:
|
297 |
+
super().__init__(model_name=model_name)
|
298 |
+
from lmflow.datasets.dataset import Dataset
|
299 |
+
from lmflow.pipeline.auto_pipeline import AutoPipeline
|
300 |
+
from lmflow.models.auto_model import AutoModel
|
301 |
+
from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
|
302 |
+
|
303 |
+
self.max_generation_token = 1000
|
304 |
+
self.end_string = "\n\n"
|
305 |
+
# We don't need input data
|
306 |
+
data_args = DatasetArguments(dataset_path=None)
|
307 |
+
self.dataset = Dataset(data_args)
|
308 |
+
self.system_prompt = ""
|
309 |
+
|
310 |
+
global LLAMA_MODEL, LLAMA_INFERENCER
|
311 |
+
if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
|
312 |
+
model_path = None
|
313 |
+
if os.path.exists("models"):
|
314 |
+
model_dirs = os.listdir("models")
|
315 |
+
if model_name in model_dirs:
|
316 |
+
model_path = f"models/{model_name}"
|
317 |
+
if model_path is not None:
|
318 |
+
model_source = model_path
|
319 |
+
else:
|
320 |
+
model_source = f"decapoda-research/{model_name}"
|
321 |
+
# raise Exception(f"models目录下没有这个模型: {model_name}")
|
322 |
+
if lora_path is not None:
|
323 |
+
lora_path = f"lora/{lora_path}"
|
324 |
+
model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
|
325 |
+
use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
|
326 |
+
pipeline_args = InferencerArguments(
|
327 |
+
local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
|
328 |
+
|
329 |
+
with open(pipeline_args.deepspeed, "r") as f:
|
330 |
+
ds_config = json.load(f)
|
331 |
+
LLAMA_MODEL = AutoModel.get_model(
|
332 |
+
model_args,
|
333 |
+
tune_strategy="none",
|
334 |
+
ds_config=ds_config,
|
335 |
+
)
|
336 |
+
LLAMA_INFERENCER = AutoPipeline.get_pipeline(
|
337 |
+
pipeline_name="inferencer",
|
338 |
+
model_args=model_args,
|
339 |
+
data_args=data_args,
|
340 |
+
pipeline_args=pipeline_args,
|
341 |
+
)
|
342 |
+
# Chats
|
343 |
+
# model_name = model_args.model_name_or_path
|
344 |
+
# if model_args.lora_model_path is not None:
|
345 |
+
# model_name += f" + {model_args.lora_model_path}"
|
346 |
+
|
347 |
+
# context = (
|
348 |
+
# "You are a helpful assistant who follows the given instructions"
|
349 |
+
# " unconditionally."
|
350 |
+
# )
|
351 |
+
|
352 |
+
def _get_llama_style_input(self):
|
353 |
+
history = []
|
354 |
+
instruction = ""
|
355 |
+
if self.system_prompt:
|
356 |
+
instruction = (f"Instruction: {self.system_prompt}\n")
|
357 |
+
for x in self.history:
|
358 |
+
if x["role"] == "user":
|
359 |
+
history.append(f"{instruction}Input: {x['content']}")
|
360 |
+
else:
|
361 |
+
history.append(f"Output: {x['content']}")
|
362 |
+
context = "\n\n".join(history)
|
363 |
+
context += "\n\nOutput: "
|
364 |
+
return context
|
365 |
+
|
366 |
+
def get_answer_at_once(self):
|
367 |
+
context = self._get_llama_style_input()
|
368 |
+
|
369 |
+
input_dataset = self.dataset.from_dict(
|
370 |
+
{"type": "text_only", "instances": [{"text": context}]}
|
371 |
+
)
|
372 |
+
|
373 |
+
output_dataset = LLAMA_INFERENCER.inference(
|
374 |
+
model=LLAMA_MODEL,
|
375 |
+
dataset=input_dataset,
|
376 |
+
max_new_tokens=self.max_generation_token,
|
377 |
+
temperature=self.temperature,
|
378 |
+
)
|
379 |
+
|
380 |
+
response = output_dataset.to_dict()["instances"][0]["text"]
|
381 |
+
return response, len(response)
|
382 |
+
|
383 |
+
def get_answer_stream_iter(self):
|
384 |
+
context = self._get_llama_style_input()
|
385 |
+
partial_text = ""
|
386 |
+
step = 1
|
387 |
+
for _ in range(0, self.max_generation_token, step):
|
388 |
+
input_dataset = self.dataset.from_dict(
|
389 |
+
{"type": "text_only", "instances": [
|
390 |
+
{"text": context + partial_text}]}
|
391 |
+
)
|
392 |
+
output_dataset = LLAMA_INFERENCER.inference(
|
393 |
+
model=LLAMA_MODEL,
|
394 |
+
dataset=input_dataset,
|
395 |
+
max_new_tokens=step,
|
396 |
+
temperature=self.temperature,
|
397 |
+
)
|
398 |
+
response = output_dataset.to_dict()["instances"][0]["text"]
|
399 |
+
if response == "" or response == self.end_string:
|
400 |
+
break
|
401 |
+
partial_text += response
|
402 |
+
yield partial_text
|
403 |
+
|
404 |
+
|
405 |
+
class XMBot_Client(BaseLLMModel):
|
406 |
+
def __init__(self, api_key):
|
407 |
+
super().__init__(model_name="xmbot")
|
408 |
+
self.api_key = api_key
|
409 |
+
self.session_id = None
|
410 |
+
self.reset()
|
411 |
+
self.image_bytes = None
|
412 |
+
self.image_path = None
|
413 |
+
self.xm_history = []
|
414 |
+
self.url = "https://xmbot.net/web"
|
415 |
+
|
416 |
+
def reset(self):
|
417 |
+
self.session_id = str(uuid.uuid4())
|
418 |
+
return [], "已重置"
|
419 |
+
|
420 |
+
def try_read_image(self, filepath):
|
421 |
+
import base64
|
422 |
+
|
423 |
+
def is_image_file(filepath):
|
424 |
+
# 判断文件是否为图片
|
425 |
+
valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
|
426 |
+
file_extension = os.path.splitext(filepath)[1].lower()
|
427 |
+
return file_extension in valid_image_extensions
|
428 |
+
|
429 |
+
def read_image_as_bytes(filepath):
|
430 |
+
# 读取图片文件并返回比特流
|
431 |
+
with open(filepath, "rb") as f:
|
432 |
+
image_bytes = f.read()
|
433 |
+
return image_bytes
|
434 |
+
|
435 |
+
if is_image_file(filepath):
|
436 |
+
logging.info(f"读取图片文件: {filepath}")
|
437 |
+
image_bytes = read_image_as_bytes(filepath)
|
438 |
+
base64_encoded_image = base64.b64encode(image_bytes).decode()
|
439 |
+
self.image_bytes = base64_encoded_image
|
440 |
+
self.image_path = filepath
|
441 |
+
else:
|
442 |
+
self.image_bytes = None
|
443 |
+
self.image_path = None
|
444 |
+
|
445 |
+
def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
|
446 |
+
fake_inputs = real_inputs
|
447 |
+
display_append = ""
|
448 |
+
limited_context = False
|
449 |
+
return limited_context, fake_inputs, display_append, real_inputs, chatbot
|
450 |
+
|
451 |
+
def handle_file_upload(self, files, chatbot):
|
452 |
+
"""if the model accepts multi modal input, implement this function"""
|
453 |
+
if files:
|
454 |
+
for file in files:
|
455 |
+
if file.name:
|
456 |
+
logging.info(f"尝试读取图像: {file.name}")
|
457 |
+
self.try_read_image(file.name)
|
458 |
+
if self.image_path is not None:
|
459 |
+
chatbot = chatbot + [((self.image_path,), None)]
|
460 |
+
if self.image_bytes is not None:
|
461 |
+
logging.info("使用图片作为输入")
|
462 |
+
conv_id = str(uuid.uuid4())
|
463 |
+
data = {
|
464 |
+
"user_id": self.api_key,
|
465 |
+
"session_id": self.session_id,
|
466 |
+
"uuid": conv_id,
|
467 |
+
"data_type": "imgbase64",
|
468 |
+
"data": self.image_bytes
|
469 |
+
}
|
470 |
+
response = requests.post(self.url, json=data)
|
471 |
+
response = json.loads(response.text)
|
472 |
+
logging.info(f"图片回复: {response['data']}")
|
473 |
+
return None, chatbot, None
|
474 |
+
|
475 |
+
def get_answer_at_once(self):
|
476 |
+
question = self.history[-1]["content"]
|
477 |
+
conv_id = str(uuid.uuid4())
|
478 |
+
data = {
|
479 |
+
"user_id": self.api_key,
|
480 |
+
"session_id": self.session_id,
|
481 |
+
"uuid": conv_id,
|
482 |
+
"data_type": "text",
|
483 |
+
"data": question
|
484 |
+
}
|
485 |
+
response = requests.post(self.url, json=data)
|
486 |
+
response = json.loads(response.text)
|
487 |
+
return response["data"], len(response["data"])
|
488 |
+
|
489 |
+
|
490 |
+
|
491 |
+
|
492 |
+
def get_model(
|
493 |
+
model_name,
|
494 |
+
lora_model_path=None,
|
495 |
+
access_key=None,
|
496 |
+
temperature=None,
|
497 |
+
top_p=None,
|
498 |
+
system_prompt=None,
|
499 |
+
) -> BaseLLMModel:
|
500 |
+
msg = f"模型设置为了: {model_name}"
|
501 |
+
model_type = ModelType.get_type(model_name)
|
502 |
+
lora_selector_visibility = False
|
503 |
+
lora_choices = []
|
504 |
+
dont_change_lora_selector = False
|
505 |
+
if model_type != ModelType.OpenAI:
|
506 |
+
config.local_embedding = True
|
507 |
+
# del current_model.model
|
508 |
+
model = None
|
509 |
+
try:
|
510 |
+
if model_type == ModelType.OpenAI:
|
511 |
+
logging.info(f"正在加载OpenAI模型: {model_name}")
|
512 |
+
model = OpenAIClient(
|
513 |
+
model_name=model_name,
|
514 |
+
api_key=access_key,
|
515 |
+
system_prompt=system_prompt,
|
516 |
+
temperature=temperature,
|
517 |
+
top_p=top_p,
|
518 |
+
)
|
519 |
+
elif model_type == ModelType.ChatGLM:
|
520 |
+
logging.info(f"正在加载ChatGLM模型: {model_name}")
|
521 |
+
model = ChatGLM_Client(model_name)
|
522 |
+
elif model_type == ModelType.LLaMA and lora_model_path == "":
|
523 |
+
msg = f"现在请为 {model_name} 选择LoRA模型"
|
524 |
+
logging.info(msg)
|
525 |
+
lora_selector_visibility = True
|
526 |
+
if os.path.isdir("lora"):
|
527 |
+
lora_choices = get_file_names(
|
528 |
+
"lora", plain=True, filetypes=[""])
|
529 |
+
lora_choices = ["No LoRA"] + lora_choices
|
530 |
+
elif model_type == ModelType.LLaMA and lora_model_path != "":
|
531 |
+
logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
|
532 |
+
dont_change_lora_selector = True
|
533 |
+
if lora_model_path == "No LoRA":
|
534 |
+
lora_model_path = None
|
535 |
+
msg += " + No LoRA"
|
536 |
+
else:
|
537 |
+
msg += f" + {lora_model_path}"
|
538 |
+
model = LLaMA_Client(model_name, lora_model_path)
|
539 |
+
elif model_type == ModelType.XMBot:
|
540 |
+
model = XMBot_Client(api_key=access_key)
|
541 |
+
elif model_type == ModelType.Unknown:
|
542 |
+
raise ValueError(f"未知模型: {model_name}")
|
543 |
+
logging.info(msg)
|
544 |
+
except Exception as e:
|
545 |
+
logging.error(e)
|
546 |
+
msg = f"{STANDARD_ERROR_MSG}: {e}"
|
547 |
+
if dont_change_lora_selector:
|
548 |
+
return model, msg
|
549 |
+
else:
|
550 |
+
return model, msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
|
551 |
+
|
552 |
+
|
553 |
+
if __name__ == "__main__":
|
554 |
+
with open("config.json", "r") as f:
|
555 |
+
openai_api_key = cjson.load(f)["openai_api_key"]
|
556 |
+
# set logging level to debug
|
557 |
+
logging.basicConfig(level=logging.DEBUG)
|
558 |
+
# client = ModelManager(model_name="gpt-3.5-turbo", access_key=openai_api_key)
|
559 |
+
client = get_model(model_name="chatglm-6b-int4")
|
560 |
+
chatbot = []
|
561 |
+
stream = False
|
562 |
+
# 测试账单功能
|
563 |
+
logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
|
564 |
+
logging.info(client.billing_info())
|
565 |
+
# 测试问答
|
566 |
+
logging.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET)
|
567 |
+
question = "巴黎是中国的首都吗?"
|
568 |
+
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
|
569 |
+
logging.info(i)
|
570 |
+
logging.info(f"测试问答后history : {client.history}")
|
571 |
+
# 测试记忆力
|
572 |
+
logging.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET)
|
573 |
+
question = "我刚刚问了你什么问题?"
|
574 |
+
for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
|
575 |
+
logging.info(i)
|
576 |
+
logging.info(f"测试记忆力后history : {client.history}")
|
577 |
+
# 测试重试功能
|
578 |
+
logging.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
|
579 |
+
for i in client.retry(chatbot=chatbot, stream=stream):
|
580 |
+
logging.info(i)
|
581 |
+
logging.info(f"重试后history : {client.history}")
|
582 |
+
# # 测试总结功能
|
583 |
+
# print(colorama.Back.GREEN + "测试总结功能" + colorama.Back.RESET)
|
584 |
+
# chatbot, msg = client.reduce_token_size(chatbot=chatbot)
|
585 |
+
# print(chatbot, msg)
|
586 |
+
# print(f"总结后history: {client.history}")
|
modules/overwrites.py
CHANGED
@@ -4,6 +4,7 @@ import logging
|
|
4 |
from llama_index import Prompt
|
5 |
from typing import List, Tuple
|
6 |
import mdtex2html
|
|
|
7 |
|
8 |
from modules.presets import *
|
9 |
from modules.llama_func import *
|
@@ -20,23 +21,60 @@ def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[st
|
|
20 |
|
21 |
|
22 |
def postprocess(
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
|
42 |
customJS = f.read()
|
|
|
4 |
from llama_index import Prompt
|
5 |
from typing import List, Tuple
|
6 |
import mdtex2html
|
7 |
+
from gradio_client import utils as client_utils
|
8 |
|
9 |
from modules.presets import *
|
10 |
from modules.llama_func import *
|
|
|
21 |
|
22 |
|
23 |
def postprocess(
|
24 |
+
self,
|
25 |
+
y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple],
|
26 |
+
) -> List[List[str | Dict | None]]:
|
27 |
+
"""
|
28 |
+
Parameters:
|
29 |
+
y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
|
30 |
+
Returns:
|
31 |
+
List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
|
32 |
+
"""
|
33 |
+
if y is None:
|
34 |
+
return []
|
35 |
+
processed_messages = []
|
36 |
+
for message_pair in y:
|
37 |
+
assert isinstance(
|
38 |
+
message_pair, (tuple, list)
|
39 |
+
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
40 |
+
assert (
|
41 |
+
len(message_pair) == 2
|
42 |
+
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
43 |
+
|
44 |
+
processed_messages.append(
|
45 |
+
[
|
46 |
+
self._postprocess_chat_messages(message_pair[0], "user"),
|
47 |
+
self._postprocess_chat_messages(message_pair[1], "bot"),
|
48 |
+
]
|
49 |
+
)
|
50 |
+
return processed_messages
|
51 |
+
|
52 |
+
def postprocess_chat_messages(
|
53 |
+
self, chat_message: str | Tuple | List | None, message_type: str
|
54 |
+
) -> str | Dict | None:
|
55 |
+
if chat_message is None:
|
56 |
+
return None
|
57 |
+
elif isinstance(chat_message, (tuple, list)):
|
58 |
+
filepath = chat_message[0]
|
59 |
+
mime_type = client_utils.get_mimetype(filepath)
|
60 |
+
filepath = self.make_temp_copy_if_needed(filepath)
|
61 |
+
return {
|
62 |
+
"name": filepath,
|
63 |
+
"mime_type": mime_type,
|
64 |
+
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
65 |
+
"data": None, # These last two fields are filled in by the frontend
|
66 |
+
"is_file": True,
|
67 |
+
}
|
68 |
+
elif isinstance(chat_message, str):
|
69 |
+
if message_type == "bot":
|
70 |
+
if not detect_converted_mark(chat_message):
|
71 |
+
chat_message = convert_mdtext(chat_message)
|
72 |
+
elif message_type == "user":
|
73 |
+
if not detect_converted_mark(chat_message):
|
74 |
+
chat_message = convert_asis(chat_message)
|
75 |
+
return chat_message
|
76 |
+
else:
|
77 |
+
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
78 |
|
79 |
with open("./assets/custom.js", "r", encoding="utf-8") as f, open("./assets/Kelpy-Codos.js", "r", encoding="utf-8") as f2:
|
80 |
customJS = f.read()
|
modules/presets.py
CHANGED
@@ -1,89 +1,122 @@
|
|
1 |
# -*- coding:utf-8 -*-
|
2 |
-
import
|
3 |
from pathlib import Path
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
# ChatGPT 设置
|
6 |
-
|
7 |
API_HOST = "api.openai.com"
|
8 |
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
|
9 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
10 |
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
|
11 |
HISTORY_DIR = Path("history")
|
|
|
12 |
TEMPLATES_DIR = "templates"
|
13 |
|
14 |
# 错误信息
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
27 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
28 |
CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
|
29 |
|
30 |
SIM_K = 5
|
31 |
INDEX_QUERY_TEMPRATURE = 1.0
|
32 |
|
33 |
-
|
34 |
-
|
35 |
<div align="center" style="margin:16px 0">
|
36 |
|
37 |
由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
|
38 |
|
39 |
-
访问川虎
|
40 |
|
41 |
-
此App使用 `gpt-3.5-turbo` 大语言模型
|
42 |
</div>
|
43 |
"""
|
44 |
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
"""
|
48 |
|
49 |
-
|
50 |
|
51 |
-
|
52 |
"gpt-3.5-turbo",
|
53 |
"gpt-3.5-turbo-0301",
|
54 |
"gpt-4",
|
55 |
"gpt-4-0314",
|
56 |
"gpt-4-32k",
|
57 |
"gpt-4-32k-0314",
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
"
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
"
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
}
|
86 |
|
|
|
|
|
|
|
|
|
87 |
REPLY_LANGUAGES = [
|
88 |
"简体中文",
|
89 |
"繁體中文",
|
|
|
1 |
# -*- coding:utf-8 -*-
|
2 |
+
import os
|
3 |
from pathlib import Path
|
4 |
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
CHATGLM_MODEL = None
|
8 |
+
CHATGLM_TOKENIZER = None
|
9 |
+
LLAMA_MODEL = None
|
10 |
+
LLAMA_INFERENCER = None
|
11 |
+
|
12 |
# ChatGPT 设置
|
13 |
+
INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
|
14 |
API_HOST = "api.openai.com"
|
15 |
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
|
16 |
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
|
17 |
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
|
18 |
HISTORY_DIR = Path("history")
|
19 |
+
HISTORY_DIR = "history"
|
20 |
TEMPLATES_DIR = "templates"
|
21 |
|
22 |
# 错误信息
|
23 |
+
STANDARD_ERROR_MSG = "☹️发生了错误:" # 错误信息的标准前缀
|
24 |
+
GENERAL_ERROR_MSG = "获取对话时发生错误,请查看后台日志"
|
25 |
+
ERROR_RETRIEVE_MSG = "请检查网络连接,或者API-Key是否有效。"
|
26 |
+
CONNECTION_TIMEOUT_MSG = "连接超时,无法获取对话。" # 连接超时
|
27 |
+
READ_TIMEOUT_MSG = "读取超时,无法获取对话。" # 读取超时
|
28 |
+
PROXY_ERROR_MSG = "代理错误,无法获取对话。" # 代理错误
|
29 |
+
SSL_ERROR_PROMPT = "SSL错误,无法获取对话。" # SSL 错误
|
30 |
+
NO_APIKEY_MSG = "API key为空,请检查是否输入正确。" # API key 长度不足 51 位
|
31 |
+
NO_INPUT_MSG = "请输入对话内容。" # 未输入对话内容
|
32 |
+
BILLING_NOT_APPLICABLE_MSG = "账单信息不适用" # 本地运行的模型返回的账单信息
|
33 |
+
|
34 |
+
TIMEOUT_STREAMING = 60 # 流式对话时的超时时间
|
35 |
+
TIMEOUT_ALL = 200 # 非流式对话时的超时时间
|
36 |
+
ENABLE_STREAMING_OPTION = True # 是否启用选择选择是否实时显示回答的勾选框
|
37 |
HIDE_MY_KEY = False # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
|
38 |
CONCURRENT_COUNT = 100 # 允许同时使用的用户数量
|
39 |
|
40 |
SIM_K = 5
|
41 |
INDEX_QUERY_TEMPRATURE = 1.0
|
42 |
|
43 |
+
CHUANHU_TITLE = """<h1 align="left">川虎Chat 🚀</h1>"""
|
44 |
+
CHUANHU_DESCRIPTION = """\
|
45 |
<div align="center" style="margin:16px 0">
|
46 |
|
47 |
由Bilibili [土川虎虎虎](https://space.bilibili.com/29125536) 和 [明昭MZhao](https://space.bilibili.com/24807452)开发
|
48 |
|
49 |
+
访问川虎Chat的 [GitHub项目](https://github.com/GaiZhenbiao/ChuanhuChatGPT) 下载最新版脚本
|
50 |
|
|
|
51 |
</div>
|
52 |
"""
|
53 |
|
54 |
+
FOOTER = """<div class="versions">{versions}</div>"""
|
55 |
+
|
56 |
+
APPEARANCE_SWITCHER = """
|
57 |
+
<div style="display: flex; justify-content: space-between;">
|
58 |
+
<span style="margin-top: 4px !important;">切换亮暗色主题</span>
|
59 |
+
<span><label class="apSwitch" for="checkbox">
|
60 |
+
<input type="checkbox" id="checkbox">
|
61 |
+
<div class="apSlider"></div>
|
62 |
+
</label></span>
|
63 |
+
</div>
|
64 |
"""
|
65 |
|
66 |
+
SUMMARIZE_PROMPT = "你是谁?我们刚才聊了什么?" # 总结对话时的 prompt
|
67 |
|
68 |
+
ONLINE_MODELS = [
|
69 |
"gpt-3.5-turbo",
|
70 |
"gpt-3.5-turbo-0301",
|
71 |
"gpt-4",
|
72 |
"gpt-4-0314",
|
73 |
"gpt-4-32k",
|
74 |
"gpt-4-32k-0314",
|
75 |
+
"xmbot",
|
76 |
+
]
|
77 |
+
|
78 |
+
LOCAL_MODELS = [
|
79 |
+
"chatglm-6b",
|
80 |
+
"chatglm-6b-int4",
|
81 |
+
"chatglm-6b-int4-qe",
|
82 |
+
"llama-7b-hf",
|
83 |
+
"llama-7b-hf-int4",
|
84 |
+
"llama-7b-hf-int8",
|
85 |
+
"llama-13b-hf",
|
86 |
+
"llama-13b-hf-int4",
|
87 |
+
"llama-30b-hf",
|
88 |
+
"llama-30b-hf-int4",
|
89 |
+
"llama-65b-hf"
|
90 |
+
]
|
91 |
+
|
92 |
+
if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
|
93 |
+
MODELS = ONLINE_MODELS
|
94 |
+
else:
|
95 |
+
MODELS = ONLINE_MODELS + LOCAL_MODELS
|
96 |
+
|
97 |
+
DEFAULT_MODEL = 0
|
98 |
+
|
99 |
+
os.makedirs("models", exist_ok=True)
|
100 |
+
os.makedirs("lora", exist_ok=True)
|
101 |
+
os.makedirs("history", exist_ok=True)
|
102 |
+
for dir_name in os.listdir("models"):
|
103 |
+
if os.path.isdir(os.path.join("models", dir_name)):
|
104 |
+
if dir_name not in MODELS:
|
105 |
+
MODELS.append(dir_name)
|
106 |
+
|
107 |
+
MODEL_TOKEN_LIMIT = {
|
108 |
+
"gpt-3.5-turbo": 4096,
|
109 |
+
"gpt-3.5-turbo-0301": 4096,
|
110 |
+
"gpt-4": 8192,
|
111 |
+
"gpt-4-0314": 8192,
|
112 |
+
"gpt-4-32k": 32768,
|
113 |
+
"gpt-4-32k-0314": 32768
|
114 |
}
|
115 |
|
116 |
+
TOKEN_OFFSET = 1000 # 模型的token上限减去这个值,得到软上限。到达软上限之后,自动尝试减少token占用。
|
117 |
+
DEFAULT_TOKEN_LIMIT = 3000 # 默认的token上限
|
118 |
+
REDUCE_TOKEN_FACTOR = 0.5 # 与模型token上限想乘,得到目标token数。减少token占用时,将token占用减少到目标token数以下。
|
119 |
+
|
120 |
REPLY_LANGUAGES = [
|
121 |
"简体中文",
|
122 |
"繁體中文",
|
modules/shared.py
CHANGED
@@ -41,11 +41,11 @@ class State:
|
|
41 |
def switching_api_key(self, func):
|
42 |
if not hasattr(self, "api_key_queue"):
|
43 |
return func
|
44 |
-
|
45 |
def wrapped(*args, **kwargs):
|
46 |
api_key = self.api_key_queue.get()
|
47 |
-
args =
|
48 |
-
ret = func(
|
49 |
self.api_key_queue.put(api_key)
|
50 |
return ret
|
51 |
|
|
|
41 |
def switching_api_key(self, func):
|
42 |
if not hasattr(self, "api_key_queue"):
|
43 |
return func
|
44 |
+
|
45 |
def wrapped(*args, **kwargs):
|
46 |
api_key = self.api_key_queue.get()
|
47 |
+
args[0].api_key = api_key
|
48 |
+
ret = func(*args, **kwargs)
|
49 |
self.api_key_queue.put(api_key)
|
50 |
return ret
|
51 |
|
modules/utils.py
CHANGED
@@ -34,6 +34,85 @@ if TYPE_CHECKING:
|
|
34 |
headers: List[str]
|
35 |
data: List[List[str | int | bool]]
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def count_token(message):
|
39 |
encoding = tiktoken.get_encoding("cl100k_base")
|
@@ -121,10 +200,13 @@ def convert_asis(userinput):
|
|
121 |
|
122 |
|
123 |
def detect_converted_mark(userinput):
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
125 |
return True
|
126 |
-
else:
|
127 |
-
return False
|
128 |
|
129 |
|
130 |
def detect_language(code):
|
@@ -153,107 +235,22 @@ def construct_assistant(text):
|
|
153 |
return construct_text("assistant", text)
|
154 |
|
155 |
|
156 |
-
def construct_token_message(tokens: List[int]):
|
157 |
-
token_sum = 0
|
158 |
-
for i in range(len(tokens)):
|
159 |
-
token_sum += sum(tokens[: i + 1])
|
160 |
-
return f"Token 计数: {sum(tokens)},本次对话累计消耗了 {token_sum} tokens"
|
161 |
-
|
162 |
-
|
163 |
-
def delete_first_conversation(history, previous_token_count):
|
164 |
-
if history:
|
165 |
-
del history[:2]
|
166 |
-
del previous_token_count[0]
|
167 |
-
return (
|
168 |
-
history,
|
169 |
-
previous_token_count,
|
170 |
-
construct_token_message(previous_token_count),
|
171 |
-
)
|
172 |
-
|
173 |
-
|
174 |
-
def delete_last_conversation(chatbot, history, previous_token_count):
|
175 |
-
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
176 |
-
logging.info("由于包含报错信息,只删除chatbot记录")
|
177 |
-
chatbot.pop()
|
178 |
-
return chatbot, history
|
179 |
-
if len(history) > 0:
|
180 |
-
logging.info("删除了一组对话历史")
|
181 |
-
history.pop()
|
182 |
-
history.pop()
|
183 |
-
if len(chatbot) > 0:
|
184 |
-
logging.info("删除了一组chatbot对话")
|
185 |
-
chatbot.pop()
|
186 |
-
if len(previous_token_count) > 0:
|
187 |
-
logging.info("删除了一组对话的token计数记录")
|
188 |
-
previous_token_count.pop()
|
189 |
-
return (
|
190 |
-
chatbot,
|
191 |
-
history,
|
192 |
-
previous_token_count,
|
193 |
-
construct_token_message(previous_token_count),
|
194 |
-
)
|
195 |
-
|
196 |
-
|
197 |
def save_file(filename, system, history, chatbot, user_name):
|
198 |
-
logging.
|
199 |
-
os.makedirs(HISTORY_DIR
|
200 |
if filename.endswith(".json"):
|
201 |
json_s = {"system": system, "history": history, "chatbot": chatbot}
|
202 |
print(json_s)
|
203 |
-
with open(os.path.join(HISTORY_DIR
|
204 |
json.dump(json_s, f)
|
205 |
elif filename.endswith(".md"):
|
206 |
md_s = f"system: \n- {system} \n"
|
207 |
for data in history:
|
208 |
md_s += f"\n{data['role']}: \n- {data['content']} \n"
|
209 |
-
with open(os.path.join(HISTORY_DIR
|
210 |
f.write(md_s)
|
211 |
-
logging.
|
212 |
-
return os.path.join(HISTORY_DIR
|
213 |
-
|
214 |
-
|
215 |
-
def save_chat_history(filename, system, history, chatbot, user_name):
|
216 |
-
if filename == "":
|
217 |
-
return
|
218 |
-
if not filename.endswith(".json"):
|
219 |
-
filename += ".json"
|
220 |
-
return save_file(filename, system, history, chatbot, user_name)
|
221 |
-
|
222 |
-
|
223 |
-
def export_markdown(filename, system, history, chatbot, user_name):
|
224 |
-
if filename == "":
|
225 |
-
return
|
226 |
-
if not filename.endswith(".md"):
|
227 |
-
filename += ".md"
|
228 |
-
return save_file(filename, system, history, chatbot, user_name)
|
229 |
-
|
230 |
-
|
231 |
-
def load_chat_history(filename, system, history, chatbot, user_name):
|
232 |
-
logging.info(f"{user_name} 加载对话历史中……")
|
233 |
-
if type(filename) != str:
|
234 |
-
filename = filename.name
|
235 |
-
try:
|
236 |
-
with open(os.path.join(HISTORY_DIR / user_name, filename), "r") as f:
|
237 |
-
json_s = json.load(f)
|
238 |
-
try:
|
239 |
-
if type(json_s["history"][0]) == str:
|
240 |
-
logging.info("历史记录格式为旧版,正在转换……")
|
241 |
-
new_history = []
|
242 |
-
for index, item in enumerate(json_s["history"]):
|
243 |
-
if index % 2 == 0:
|
244 |
-
new_history.append(construct_user(item))
|
245 |
-
else:
|
246 |
-
new_history.append(construct_assistant(item))
|
247 |
-
json_s["history"] = new_history
|
248 |
-
logging.info(new_history)
|
249 |
-
except:
|
250 |
-
# 没有对话历史
|
251 |
-
pass
|
252 |
-
logging.info(f"{user_name} 加载对话历史完毕")
|
253 |
-
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
|
254 |
-
except FileNotFoundError:
|
255 |
-
logging.info(f"{user_name} 没有找到对话历史文件,不执行任何操作")
|
256 |
-
return filename, system, history, chatbot
|
257 |
|
258 |
|
259 |
def sorted_by_pinyin(list):
|
@@ -261,7 +258,7 @@ def sorted_by_pinyin(list):
|
|
261 |
|
262 |
|
263 |
def get_file_names(dir, plain=False, filetypes=[".json"]):
|
264 |
-
logging.
|
265 |
files = []
|
266 |
try:
|
267 |
for type in filetypes:
|
@@ -279,14 +276,13 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
|
|
279 |
|
280 |
|
281 |
def get_history_names(plain=False, user_name=""):
|
282 |
-
logging.
|
283 |
-
return get_file_names(HISTORY_DIR
|
284 |
|
285 |
|
286 |
def load_template(filename, mode=0):
|
287 |
-
logging.
|
288 |
lines = []
|
289 |
-
logging.info("Loading template...")
|
290 |
if filename.endswith(".json"):
|
291 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
292 |
lines = json.load(f)
|
@@ -310,23 +306,18 @@ def load_template(filename, mode=0):
|
|
310 |
|
311 |
|
312 |
def get_template_names(plain=False):
|
313 |
-
logging.
|
314 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
315 |
|
316 |
|
317 |
def get_template_content(templates, selection, original_system_prompt):
|
318 |
-
logging.
|
319 |
try:
|
320 |
return templates[selection]
|
321 |
except:
|
322 |
return original_system_prompt
|
323 |
|
324 |
|
325 |
-
def reset_state():
|
326 |
-
logging.info("重置状态")
|
327 |
-
return [], [], [], construct_token_message([0])
|
328 |
-
|
329 |
-
|
330 |
def reset_textbox():
|
331 |
logging.debug("重置文本框")
|
332 |
return gr.update(value="")
|
@@ -388,7 +379,7 @@ def get_geoip():
|
|
388 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
389 |
if data["reason"] == "RateLimited":
|
390 |
return (
|
391 |
-
f"
|
392 |
)
|
393 |
else:
|
394 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
@@ -418,7 +409,7 @@ def find_n(lst, max_num):
|
|
418 |
|
419 |
def start_outputing():
|
420 |
logging.debug("显示取消按钮,隐藏发送按钮")
|
421 |
-
return gr.Button.update(visible=
|
422 |
|
423 |
|
424 |
def end_outputing():
|
@@ -440,8 +431,8 @@ def transfer_input(inputs):
|
|
440 |
return (
|
441 |
inputs,
|
442 |
gr.update(value=""),
|
443 |
-
gr.Button.update(visible=True),
|
444 |
gr.Button.update(visible=False),
|
|
|
445 |
)
|
446 |
|
447 |
|
@@ -504,15 +495,15 @@ def add_details(lst):
|
|
504 |
return nodes
|
505 |
|
506 |
|
507 |
-
def sheet_to_string(sheet):
|
508 |
-
result =
|
509 |
for index, row in sheet.iterrows():
|
510 |
row_string = ""
|
511 |
for column in sheet.columns:
|
512 |
row_string += f"{column}: {row[column]}, "
|
513 |
row_string = row_string.rstrip(", ")
|
514 |
row_string += "."
|
515 |
-
result
|
516 |
return result
|
517 |
|
518 |
def excel_to_string(file_path):
|
@@ -520,17 +511,23 @@ def excel_to_string(file_path):
|
|
520 |
excel_file = pd.read_excel(file_path, engine='openpyxl', sheet_name=None)
|
521 |
|
522 |
# 初始化结果字符串
|
523 |
-
result =
|
524 |
|
525 |
# 遍历每一个工作表
|
526 |
for sheet_name, sheet_data in excel_file.items():
|
527 |
-
# 将工作表名称添加到结果字符串
|
528 |
-
result += f"Sheet: {sheet_name}\n"
|
529 |
|
530 |
# 处理当前工作表并添加到结果字符串
|
531 |
-
result += sheet_to_string(sheet_data)
|
532 |
|
533 |
-
# 在不同工作表之间添加分隔符
|
534 |
-
result += "\n" + ("-" * 20) + "\n\n"
|
535 |
|
536 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
headers: List[str]
|
35 |
data: List[List[str | int | bool]]
|
36 |
|
37 |
+
def predict(current_model, *args):
|
38 |
+
iter = current_model.predict(*args)
|
39 |
+
for i in iter:
|
40 |
+
yield i
|
41 |
+
|
42 |
+
def billing_info(current_model):
|
43 |
+
return current_model.billing_info()
|
44 |
+
|
45 |
+
def set_key(current_model, *args):
|
46 |
+
return current_model.set_key(*args)
|
47 |
+
|
48 |
+
def load_chat_history(current_model, *args):
|
49 |
+
return current_model.load_chat_history(*args)
|
50 |
+
|
51 |
+
def interrupt(current_model, *args):
|
52 |
+
return current_model.interrupt(*args)
|
53 |
+
|
54 |
+
def reset(current_model, *args):
|
55 |
+
return current_model.reset(*args)
|
56 |
+
|
57 |
+
def retry(current_model, *args):
|
58 |
+
iter = current_model.retry(*args)
|
59 |
+
for i in iter:
|
60 |
+
yield i
|
61 |
+
|
62 |
+
def delete_first_conversation(current_model, *args):
|
63 |
+
return current_model.delete_first_conversation(*args)
|
64 |
+
|
65 |
+
def delete_last_conversation(current_model, *args):
|
66 |
+
return current_model.delete_last_conversation(*args)
|
67 |
+
|
68 |
+
def set_system_prompt(current_model, *args):
|
69 |
+
return current_model.set_system_prompt(*args)
|
70 |
+
|
71 |
+
def save_chat_history(current_model, *args):
|
72 |
+
return current_model.save_chat_history(*args)
|
73 |
+
|
74 |
+
def export_markdown(current_model, *args):
|
75 |
+
return current_model.export_markdown(*args)
|
76 |
+
|
77 |
+
def load_chat_history(current_model, *args):
|
78 |
+
return current_model.load_chat_history(*args)
|
79 |
+
|
80 |
+
def set_token_upper_limit(current_model, *args):
|
81 |
+
return current_model.set_token_upper_limit(*args)
|
82 |
+
|
83 |
+
def set_temperature(current_model, *args):
|
84 |
+
current_model.set_temperature(*args)
|
85 |
+
|
86 |
+
def set_top_p(current_model, *args):
|
87 |
+
current_model.set_top_p(*args)
|
88 |
+
|
89 |
+
def set_n_choices(current_model, *args):
|
90 |
+
current_model.set_n_choices(*args)
|
91 |
+
|
92 |
+
def set_stop_sequence(current_model, *args):
|
93 |
+
current_model.set_stop_sequence(*args)
|
94 |
+
|
95 |
+
def set_max_tokens(current_model, *args):
|
96 |
+
current_model.set_max_tokens(*args)
|
97 |
+
|
98 |
+
def set_presence_penalty(current_model, *args):
|
99 |
+
current_model.set_presence_penalty(*args)
|
100 |
+
|
101 |
+
def set_frequency_penalty(current_model, *args):
|
102 |
+
current_model.set_frequency_penalty(*args)
|
103 |
+
|
104 |
+
def set_logit_bias(current_model, *args):
|
105 |
+
current_model.set_logit_bias(*args)
|
106 |
+
|
107 |
+
def set_user_identifier(current_model, *args):
|
108 |
+
current_model.set_user_identifier(*args)
|
109 |
+
|
110 |
+
def set_single_turn(current_model, *args):
|
111 |
+
current_model.set_single_turn(*args)
|
112 |
+
|
113 |
+
def handle_file_upload(current_model, *args):
|
114 |
+
return current_model.handle_file_upload(*args)
|
115 |
+
|
116 |
|
117 |
def count_token(message):
|
118 |
encoding = tiktoken.get_encoding("cl100k_base")
|
|
|
200 |
|
201 |
|
202 |
def detect_converted_mark(userinput):
|
203 |
+
try:
|
204 |
+
if userinput.endswith(ALREADY_CONVERTED_MARK):
|
205 |
+
return True
|
206 |
+
else:
|
207 |
+
return False
|
208 |
+
except:
|
209 |
return True
|
|
|
|
|
210 |
|
211 |
|
212 |
def detect_language(code):
|
|
|
235 |
return construct_text("assistant", text)
|
236 |
|
237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
def save_file(filename, system, history, chatbot, user_name):
|
239 |
+
logging.debug(f"{user_name} 保存对话历史中……")
|
240 |
+
os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
|
241 |
if filename.endswith(".json"):
|
242 |
json_s = {"system": system, "history": history, "chatbot": chatbot}
|
243 |
print(json_s)
|
244 |
+
with open(os.path.join(HISTORY_DIR, user_name, filename), "w") as f:
|
245 |
json.dump(json_s, f)
|
246 |
elif filename.endswith(".md"):
|
247 |
md_s = f"system: \n- {system} \n"
|
248 |
for data in history:
|
249 |
md_s += f"\n{data['role']}: \n- {data['content']} \n"
|
250 |
+
with open(os.path.join(HISTORY_DIR, user_name, filename), "w", encoding="utf8") as f:
|
251 |
f.write(md_s)
|
252 |
+
logging.debug(f"{user_name} 保存对话历史完毕")
|
253 |
+
return os.path.join(HISTORY_DIR, user_name, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
|
256 |
def sorted_by_pinyin(list):
|
|
|
258 |
|
259 |
|
260 |
def get_file_names(dir, plain=False, filetypes=[".json"]):
|
261 |
+
logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
|
262 |
files = []
|
263 |
try:
|
264 |
for type in filetypes:
|
|
|
276 |
|
277 |
|
278 |
def get_history_names(plain=False, user_name=""):
|
279 |
+
logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
|
280 |
+
return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
|
281 |
|
282 |
|
283 |
def load_template(filename, mode=0):
|
284 |
+
logging.debug(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
|
285 |
lines = []
|
|
|
286 |
if filename.endswith(".json"):
|
287 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
288 |
lines = json.load(f)
|
|
|
306 |
|
307 |
|
308 |
def get_template_names(plain=False):
|
309 |
+
logging.debug("获取模板文件名列表")
|
310 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
311 |
|
312 |
|
313 |
def get_template_content(templates, selection, original_system_prompt):
|
314 |
+
logging.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
|
315 |
try:
|
316 |
return templates[selection]
|
317 |
except:
|
318 |
return original_system_prompt
|
319 |
|
320 |
|
|
|
|
|
|
|
|
|
|
|
321 |
def reset_textbox():
|
322 |
logging.debug("重置文本框")
|
323 |
return gr.update(value="")
|
|
|
379 |
logging.warning(f"无法获取IP地址信息。\n{data}")
|
380 |
if data["reason"] == "RateLimited":
|
381 |
return (
|
382 |
+
f"您的IP区域:未知。"
|
383 |
)
|
384 |
else:
|
385 |
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
|
|
409 |
|
410 |
def start_outputing():
|
411 |
logging.debug("显示取消按钮,隐藏发送按钮")
|
412 |
+
return gr.Button.update(visible=False), gr.Button.update(visible=True)
|
413 |
|
414 |
|
415 |
def end_outputing():
|
|
|
431 |
return (
|
432 |
inputs,
|
433 |
gr.update(value=""),
|
|
|
434 |
gr.Button.update(visible=False),
|
435 |
+
gr.Button.update(visible=True),
|
436 |
)
|
437 |
|
438 |
|
|
|
495 |
return nodes
|
496 |
|
497 |
|
498 |
+
def sheet_to_string(sheet, sheet_name = None):
|
499 |
+
result = []
|
500 |
for index, row in sheet.iterrows():
|
501 |
row_string = ""
|
502 |
for column in sheet.columns:
|
503 |
row_string += f"{column}: {row[column]}, "
|
504 |
row_string = row_string.rstrip(", ")
|
505 |
row_string += "."
|
506 |
+
result.append(row_string)
|
507 |
return result
|
508 |
|
509 |
def excel_to_string(file_path):
|
|
|
511 |
excel_file = pd.read_excel(file_path, engine='openpyxl', sheet_name=None)
|
512 |
|
513 |
# 初始化结果字符串
|
514 |
+
result = []
|
515 |
|
516 |
# 遍历每一个工作表
|
517 |
for sheet_name, sheet_data in excel_file.items():
|
|
|
|
|
518 |
|
519 |
# 处理当前工作表并添加到结果字符串
|
520 |
+
result += sheet_to_string(sheet_data, sheet_name=sheet_name)
|
521 |
|
|
|
|
|
522 |
|
523 |
return result
|
524 |
+
|
525 |
+
def get_last_day_of_month(any_day):
|
526 |
+
# The day 28 exists in every month. 4 days later, it's always next month
|
527 |
+
next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
|
528 |
+
# subtracting the number of the current day brings us back one month
|
529 |
+
return next_month - datetime.timedelta(days=next_month.day)
|
530 |
+
|
531 |
+
def get_model_source(model_name, alternative_source):
|
532 |
+
if model_name == "gpt2-medium":
|
533 |
+
return "https://huggingface.co/gpt2-medium"
|
requirements.txt
CHANGED
@@ -13,3 +13,4 @@ markdown
|
|
13 |
PyPDF2
|
14 |
pdfplumber
|
15 |
pandas
|
|
|
|
13 |
PyPDF2
|
14 |
pdfplumber
|
15 |
pandas
|
16 |
+
commentjson
|
requirements_advanced.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
icetk
|
4 |
+
protobuf==3.19.0
|
5 |
+
git+https://github.com/OptimalScale/LMFlow.git
|
6 |
+
cpm-kernels
|
7 |
+
sentence_transformers
|
run_Linux.sh
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
#!/bin/bash
|
2 |
|
3 |
# 获取脚本所在目录
|
4 |
-
script_dir=$(dirname "$0")
|
5 |
|
6 |
# 将工作目录更改为脚本所在目录
|
7 |
-
cd "$script_dir"
|
8 |
|
9 |
# 检查Git仓库是否有更新
|
10 |
git remote update
|
@@ -23,3 +23,9 @@ if ! git status -uno | grep 'up to date' > /dev/null; then
|
|
23 |
# 重新启动服务器
|
24 |
nohup python3 ChuanhuChatbot.py &
|
25 |
fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/bin/bash
|
2 |
|
3 |
# 获取脚本所在目录
|
4 |
+
script_dir=$(dirname "$(readlink -f "$0")")
|
5 |
|
6 |
# 将工作目录更改为脚本所在目录
|
7 |
+
cd "$script_dir" || exit
|
8 |
|
9 |
# 检查Git仓库是否有更新
|
10 |
git remote update
|
|
|
23 |
# 重新启动服务器
|
24 |
nohup python3 ChuanhuChatbot.py &
|
25 |
fi
|
26 |
+
|
27 |
+
# 检查ChuanhuChatbot.py是否在运行
|
28 |
+
if ! pgrep -f ChuanhuChatbot.py > /dev/null; then
|
29 |
+
# 如果没有运行,启动服务器
|
30 |
+
nohup python3 ChuanhuChatbot.py &
|
31 |
+
fi
|
run_macOS.command
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
#!/bin/bash
|
2 |
|
3 |
# 获取脚本所在目录
|
4 |
-
script_dir=$(dirname "$0")
|
5 |
|
6 |
# 将工作目录更改为脚本所在目录
|
7 |
-
cd "$script_dir"
|
8 |
|
9 |
# 检查Git仓库是否有更新
|
10 |
git remote update
|
@@ -23,3 +23,9 @@ if ! git status -uno | grep 'up to date' > /dev/null; then
|
|
23 |
# 重新启动服务器
|
24 |
nohup python3 ChuanhuChatbot.py &
|
25 |
fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/bin/bash
|
2 |
|
3 |
# 获取脚本所在目录
|
4 |
+
script_dir=$(dirname "$(readlink -f "$0")")
|
5 |
|
6 |
# 将工作目录更改为脚本所在目录
|
7 |
+
cd "$script_dir" || exit
|
8 |
|
9 |
# 检查Git仓库是否有更新
|
10 |
git remote update
|
|
|
23 |
# 重新启动服务器
|
24 |
nohup python3 ChuanhuChatbot.py &
|
25 |
fi
|
26 |
+
|
27 |
+
# 检查ChuanhuChatbot.py是否在运行
|
28 |
+
if ! pgrep -f ChuanhuChatbot.py > /dev/null; then
|
29 |
+
# 如果没有运行,启动服务器
|
30 |
+
nohup python3 ChuanhuChatbot.py &
|
31 |
+
fi
|