Tuchuanhuhuhu commited on
Commit
03db1ff
1 Parent(s): 08b7713

feat: 对话历史按时间排序

Browse files
.gitignore CHANGED
@@ -140,7 +140,7 @@ dmypy.json
140
  api_key.txt
141
  config.json
142
  auth.json
143
- models/
144
  lora/
145
  .idea
146
  templates/*
 
140
  api_key.txt
141
  config.json
142
  auth.json
143
+ .models/
144
  lora/
145
  .idea
146
  templates/*
ChuanhuChatbot.py CHANGED
@@ -1,7 +1,7 @@
1
  # -*- coding:utf-8 -*-
2
  import logging
3
  logging.basicConfig(
4
- level=logging.INFO,
5
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
6
  )
7
 
@@ -31,7 +31,7 @@ def create_new_model():
31
 
32
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
33
  user_name = gr.State("")
34
- promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
35
  user_question = gr.State("")
36
  assert type(my_api_key)==str
37
  user_api_key = gr.State(my_api_key)
@@ -135,9 +135,9 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
135
  with gr.Column(scale=6):
136
  templateFileSelectDropdown = gr.Dropdown(
137
  label=i18n("选择Prompt模板集合文件"),
138
- choices=get_template_names(plain=True),
139
  multiselect=False,
140
- value=get_template_names(plain=True)[0],
141
  container=False,
142
  )
143
  with gr.Column(scale=1):
@@ -147,7 +147,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
147
  templateSelectDropdown = gr.Dropdown(
148
  label=i18n("从Prompt模板中加载"),
149
  choices=load_template(
150
- get_template_names(plain=True)[0], mode=1
151
  ),
152
  multiselect=False,
153
  container=False,
@@ -160,7 +160,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
160
  with gr.Column(scale=6):
161
  historyFileSelectDropdown = gr.Dropdown(
162
  label=i18n("从列表中加载对话"),
163
- choices=get_history_names(plain=True),
164
  multiselect=False,
165
  container=False,
166
  )
@@ -185,7 +185,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
185
  gr.Markdown(i18n("默认保存于history文件夹"))
186
  with gr.Row():
187
  with gr.Column():
188
- downloadFile = gr.File(interactive=True)
189
 
190
  with gr.Tab(label=i18n("微调")):
191
  openai_train_status = gr.Markdown(label=i18n("训练状态"), value=i18n("在这里[查看使用介绍](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B#%E5%BE%AE%E8%B0%83-gpt-35)"))
@@ -336,7 +336,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
336
  current_model = get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
337
  current_model.set_user_identifier(user_name)
338
  chatbot = gr.Chatbot.update(label=MODELS[DEFAULT_MODEL])
339
- return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), *current_model.auto_load(), get_history_names(False, user_name), chatbot
340
  demo.load(create_greeting, inputs=None, outputs=[user_info, user_name, current_model, like_dislike_area, systemPromptTxt, chatbot, historyFileSelectDropdown, chatbot], api_name="load")
341
  chatgpt_predict_args = dict(
342
  fn=predict,
@@ -383,7 +383,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
383
  )
384
 
385
  refresh_history_args = dict(
386
- fn=get_history_names, inputs=[gr.State(False), user_name], outputs=[historyFileSelectDropdown]
387
  )
388
 
389
 
@@ -461,8 +461,8 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
461
 
462
  # Template
463
  systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
464
- templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
465
- templateFileSelectDropdown.change(
466
  load_template,
467
  [templateFileSelectDropdown],
468
  [promptTemplates, templateSelectDropdown],
@@ -482,7 +482,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
482
  downloadFile,
483
  show_progress=True,
484
  )
485
- saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
486
  exportMarkdownBtn.click(
487
  export_markdown,
488
  [current_model, saveFileName, chatbot, user_name],
 
1
  # -*- coding:utf-8 -*-
2
  import logging
3
  logging.basicConfig(
4
+ level=logging.DEBUG,
5
  format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
6
  )
7
 
 
31
 
32
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
33
  user_name = gr.State("")
34
+ promptTemplates = gr.State(load_template(get_template_names()[0], mode=2))
35
  user_question = gr.State("")
36
  assert type(my_api_key)==str
37
  user_api_key = gr.State(my_api_key)
 
135
  with gr.Column(scale=6):
136
  templateFileSelectDropdown = gr.Dropdown(
137
  label=i18n("选择Prompt模板集合文件"),
138
+ choices=get_template_names(),
139
  multiselect=False,
140
+ value=get_template_names()[0],
141
  container=False,
142
  )
143
  with gr.Column(scale=1):
 
147
  templateSelectDropdown = gr.Dropdown(
148
  label=i18n("从Prompt模板中加载"),
149
  choices=load_template(
150
+ get_template_names()[0], mode=1
151
  ),
152
  multiselect=False,
153
  container=False,
 
160
  with gr.Column(scale=6):
161
  historyFileSelectDropdown = gr.Dropdown(
162
  label=i18n("从列表中加载对话"),
163
+ choices=get_history_names(),
164
  multiselect=False,
165
  container=False,
166
  )
 
185
  gr.Markdown(i18n("默认保存于history文件夹"))
186
  with gr.Row():
187
  with gr.Column():
188
+ downloadFile = gr.File(interactive=True, label="下载/上传历史记录")
189
 
190
  with gr.Tab(label=i18n("微调")):
191
  openai_train_status = gr.Markdown(label=i18n("训练状态"), value=i18n("在这里[查看使用介绍](https://github.com/GaiZhenbiao/ChuanhuChatGPT/wiki/%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B#%E5%BE%AE%E8%B0%83-gpt-35)"))
 
336
  current_model = get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
337
  current_model.set_user_identifier(user_name)
338
  chatbot = gr.Chatbot.update(label=MODELS[DEFAULT_MODEL])
339
+ return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), *current_model.auto_load(), get_history_dropdown(user_name), chatbot
340
  demo.load(create_greeting, inputs=None, outputs=[user_info, user_name, current_model, like_dislike_area, systemPromptTxt, chatbot, historyFileSelectDropdown, chatbot], api_name="load")
341
  chatgpt_predict_args = dict(
342
  fn=predict,
 
383
  )
384
 
385
  refresh_history_args = dict(
386
+ fn=get_history_dropdown, inputs=[user_name], outputs=[historyFileSelectDropdown]
387
  )
388
 
389
 
 
461
 
462
  # Template
463
  systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
464
+ templateRefreshBtn.click(get_template_dropdown, None, [templateFileSelectDropdown])
465
+ templateFileSelectDropdown.input(
466
  load_template,
467
  [templateFileSelectDropdown],
468
  [promptTemplates, templateSelectDropdown],
 
482
  downloadFile,
483
  show_progress=True,
484
  )
485
+ saveHistoryBtn.click(get_history_dropdown, [user_name], [historyFileSelectDropdown])
486
  exportMarkdownBtn.click(
487
  export_markdown,
488
  [current_model, saveFileName, chatbot, user_name],
modules/models/base_model.py CHANGED
@@ -724,7 +724,7 @@ class BaseLLMModel:
724
  history_file_path = filename
725
  try:
726
  os.remove(history_file_path)
727
- return i18n("删除对话历史成功"), get_history_names(False, user_name), []
728
  except:
729
  logging.info(f"删除对话历史失败 {history_file_path}")
730
  return i18n("对话历史")+filename+i18n("已经被删除啦"), gr.update(), gr.update()
 
724
  history_file_path = filename
725
  try:
726
  os.remove(history_file_path)
727
+ return i18n("删除对话历史成功"), get_history_dropdown(user_name), []
728
  except:
729
  logging.info(f"删除对话历史失败 {history_file_path}")
730
  return i18n("对话历史")+filename+i18n("已经被删除啦"), gr.update(), gr.update()
modules/models/models.py CHANGED
@@ -580,8 +580,8 @@ def get_model(
580
  logging.info(msg)
581
  lora_selector_visibility = True
582
  if os.path.isdir("lora"):
583
- lora_choices = get_file_names(
584
- "lora", plain=True, filetypes=[""])
585
  lora_choices = ["No LoRA"] + lora_choices
586
  elif model_type == ModelType.LLaMA and lora_model_path != "":
587
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
 
580
  logging.info(msg)
581
  lora_selector_visibility = True
582
  if os.path.isdir("lora"):
583
+ lora_choices = get_file_names_dropdown_by_pinyin(
584
+ "lora", filetypes=[""])
585
  lora_choices = ["No LoRA"] + lora_choices
586
  elif model_type == ModelType.LLaMA and lora_model_path != "":
587
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
modules/models/spark.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import _thread as thread
2
+ import base64
3
+ import datetime
4
+ import hashlib
5
+ import hmac
6
+ import json
7
+ from collections import deque
8
+ from urllib.parse import urlparse
9
+ import ssl
10
+ from datetime import datetime
11
+ from time import mktime
12
+ from urllib.parse import urlencode
13
+ from wsgiref.handlers import format_date_time
14
+ from threading import Condition
15
+ import websocket
16
+ import logging
17
+
18
+ from .base_model import BaseLLMModel, CallbackToIterator
19
+
20
+
21
+ class Ws_Param(object):
22
+ # 来自官方 Demo
23
+ # 初始化
24
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
25
+ self.APPID = APPID
26
+ self.APIKey = APIKey
27
+ self.APISecret = APISecret
28
+ self.host = urlparse(Spark_url).netloc
29
+ self.path = urlparse(Spark_url).path
30
+ self.Spark_url = Spark_url
31
+
32
+ # 生成url
33
+ def create_url(self):
34
+ # 生成RFC1123格式的时间戳
35
+ now = datetime.now()
36
+ date = format_date_time(mktime(now.timetuple()))
37
+
38
+ # 拼接字符串
39
+ signature_origin = "host: " + self.host + "\n"
40
+ signature_origin += "date: " + date + "\n"
41
+ signature_origin += "GET " + self.path + " HTTP/1.1"
42
+
43
+ # 进行hmac-sha256进行加密
44
+ signature_sha = hmac.new(
45
+ self.APISecret.encode("utf-8"),
46
+ signature_origin.encode("utf-8"),
47
+ digestmod=hashlib.sha256,
48
+ ).digest()
49
+
50
+ signature_sha_base64 = base64.b64encode(
51
+ signature_sha).decode(encoding="utf-8")
52
+
53
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
54
+
55
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
56
+ encoding="utf-8"
57
+ )
58
+
59
+ # 将请求的鉴权参数组合为字典
60
+ v = {"authorization": authorization, "date": date, "host": self.host}
61
+ # 拼接鉴权参数,生成url
62
+ url = self.Spark_url + "?" + urlencode(v)
63
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
64
+ return url
65
+
66
+
67
+ class Spark_Client(BaseLLMModel):
68
+ def __init__(self, model_name, appid, api_key, api_secret, user_name="") -> None:
69
+ super().__init__(model_name=model_name, user=user_name)
70
+ self.api_key = api_key
71
+ self.appid = appid
72
+ self.api_secret = api_secret
73
+ if None in [self.api_key, self.appid, self.api_secret]:
74
+ raise Exception("请在配置文件或者环境变量中设置讯飞的API Key、APP ID和API Secret")
75
+ if "2.0" in self.model_name:
76
+ self.spark_url = "wss://spark-api.xf-yun.com/v2.1/chat"
77
+ self.domain = "generalv2"
78
+ else:
79
+ self.spark_url = "wss://spark-api.xf-yun.com/v1.1/chat"
80
+ self.domain = "general"
81
+
82
+ # 收到websocket错误的处理
83
+ def on_error(self, ws, error):
84
+ ws.iterator.callback("出现了错误:" + error)
85
+
86
+ # 收到websocket关闭的处理
87
+ def on_close(self, ws, one, two):
88
+ pass
89
+
90
+ # 收到websocket连接建立的处理
91
+ def on_open(self, ws):
92
+ thread.start_new_thread(self.run, (ws,))
93
+
94
+ def run(self, ws, *args):
95
+ data = json.dumps(
96
+ self.gen_params()
97
+ )
98
+ ws.send(data)
99
+
100
+ # 收到websocket消息的处理
101
+ def on_message(self, ws, message):
102
+ ws.iterator.callback(message)
103
+
104
+ def gen_params(self):
105
+ """
106
+ 通过appid和用户的提问来生成请参数
107
+ """
108
+ data = {
109
+ "header": {"app_id": self.appid, "uid": "1234"},
110
+ "parameter": {
111
+ "chat": {
112
+ "domain": self.domain,
113
+ "random_threshold": self.temperature,
114
+ "max_tokens": 4096,
115
+ "auditing": "default",
116
+ }
117
+ },
118
+ "payload": {"message": {"text": self.history}},
119
+ }
120
+ return data
121
+
122
+ def get_answer_stream_iter(self):
123
+ wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.spark_url)
124
+ websocket.enableTrace(False)
125
+ wsUrl = wsParam.create_url()
126
+ ws = websocket.WebSocketApp(
127
+ wsUrl,
128
+ on_message=self.on_message,
129
+ on_error=self.on_error,
130
+ on_close=self.on_close,
131
+ on_open=self.on_open,
132
+ )
133
+ ws.appid = self.appid
134
+ ws.domain = self.domain
135
+
136
+ # Initialize the CallbackToIterator
137
+ ws.iterator = CallbackToIterator()
138
+
139
+ # Start the WebSocket connection in a separate thread
140
+ thread.start_new_thread(
141
+ ws.run_forever, (), {"sslopt": {"cert_reqs": ssl.CERT_NONE}}
142
+ )
143
+
144
+ # Iterate over the CallbackToIterator instance
145
+ answer = ""
146
+ total_tokens = 0
147
+ for message in ws.iterator:
148
+ data = json.loads(message)
149
+ code = data["header"]["code"]
150
+ if code != 0:
151
+ ws.close()
152
+ raise Exception(f"请求错误: {code}, {data}")
153
+ else:
154
+ choices = data["payload"]["choices"]
155
+ status = choices["status"]
156
+ content = choices["text"][0]["content"]
157
+ if "usage" in data["payload"]:
158
+ total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
159
+ answer += content
160
+ if status == 2:
161
+ ws.iterator.finish() # Finish the iterator when the status is 2
162
+ ws.close()
163
+ yield answer, total_tokens
modules/utils.py CHANGED
@@ -354,31 +354,50 @@ def save_file(filename, system, history, chatbot, user_name):
354
  def sorted_by_pinyin(list):
355
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
356
 
 
 
357
 
358
- def get_file_names(dir, plain=False, filetypes=[".json"]):
359
- logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
360
  files = []
361
  try:
362
  for type in filetypes:
363
  files += [f for f in os.listdir(dir) if f.endswith(type)]
364
  except FileNotFoundError:
365
- files = []
366
- files = sorted_by_pinyin(files)
367
- if files == []:
368
  files = [""]
369
  logging.debug(f"files are:{files}")
370
- if plain:
371
- return files
372
- else:
373
- return gr.Dropdown.update(choices=files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
 
376
- def get_history_names(plain=False, user_name=""):
377
  logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
378
  if user_name == "" and hide_history_when_not_logged_in:
379
  return ""
380
  else:
381
- return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
 
 
 
 
 
382
 
383
 
384
  def load_template(filename, mode=0):
@@ -406,9 +425,14 @@ def load_template(filename, mode=0):
406
  )
407
 
408
 
409
- def get_template_names(plain=False):
410
  logging.debug("获取模板文件名列表")
411
- return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 
 
 
 
412
 
413
 
414
  def get_template_content(templates, selection, original_system_prompt):
 
354
  def sorted_by_pinyin(list):
355
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
356
 
357
+ def sorted_by_last_modified_time(list, dir):
358
+ return sorted(list, key=lambda char: os.path.getmtime(os.path.join(dir, char)), reverse=True)
359
 
360
+ def get_file_names_by_type(dir, filetypes=[".json"]):
361
+ logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes}")
362
  files = []
363
  try:
364
  for type in filetypes:
365
  files += [f for f in os.listdir(dir) if f.endswith(type)]
366
  except FileNotFoundError:
 
 
 
367
  files = [""]
368
  logging.debug(f"files are:{files}")
369
+ return files
370
+
371
+ def get_file_names_by_pinyin(dir, filetypes=[".json"]):
372
+ files = get_file_names_by_type(dir, filetypes)
373
+ if files != [""]:
374
+ files = sorted_by_pinyin(files)
375
+ logging.debug(f"files are:{files}")
376
+ return files
377
+
378
+ def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]):
379
+ files = get_file_names_by_pinyin(dir, filetypes)
380
+ return gr.Dropdown.update(choices=files)
381
+
382
+ def get_file_names_by_last_modified_time(dir, filetypes=[".json"]):
383
+ files = get_file_names_by_type(dir, filetypes)
384
+ if files != [""]:
385
+ files = sorted_by_last_modified_time(files, dir)
386
+ logging.debug(f"files are:{files}")
387
+ return files
388
 
389
 
390
+ def get_history_names(user_name=""):
391
  logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
392
  if user_name == "" and hide_history_when_not_logged_in:
393
  return ""
394
  else:
395
+ history_files = get_file_names_by_last_modified_time(os.path.join(HISTORY_DIR, user_name))
396
+ return history_files
397
+
398
+ def get_history_dropdown(user_name=""):
399
+ history_names = get_history_names(user_name)
400
+ return gr.Dropdown.update(choices=history_names)
401
 
402
 
403
  def load_template(filename, mode=0):
 
425
  )
426
 
427
 
428
+ def get_template_names():
429
  logging.debug("获取模板文件名列表")
430
+ return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"])
431
+
432
+ def get_template_dropdown():
433
+ logging.debug("获取模板下拉菜单")
434
+ template_names = get_template_names()
435
+ return gr.Dropdown.update(choices=template_names)
436
 
437
 
438
  def get_template_content(templates, selection, original_system_prompt):