Keldos commited on
Commit
601c367
2 Parent(s): 3290e22 d55a9fe

Merge branch 'main' into UI-new

Browse files
.gitignore CHANGED
@@ -140,7 +140,7 @@ dmypy.json
140
  api_key.txt
141
  config.json
142
  auth.json
143
- models/
144
  lora/
145
  .idea
146
  templates/*
 
140
  api_key.txt
141
  config.json
142
  auth.json
143
+ .models/
144
  lora/
145
  .idea
146
  templates/*
ChuanhuChatbot.py CHANGED
@@ -31,7 +31,7 @@ def create_new_model():
31
 
32
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
33
  user_name = gr.State("")
34
- promptTemplates = gr.State(load_template(get_template_names(plain=True)[0], mode=2))
35
  user_question = gr.State("")
36
  assert type(my_api_key)==str
37
  user_api_key = gr.State(my_api_key)
@@ -64,7 +64,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
64
  with gr.Column(scale=6, elem_id="history-select-wrap"):
65
  historyFileSelectDropdown = gr.Radio(
66
  label=i18n("从列表中加载对话"),
67
- choices=get_history_names(plain=True),
68
  # multiselect=False,
69
  container=False,
70
  elem_id="history-select-dropdown"
@@ -90,7 +90,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
90
  gr.Markdown(i18n("默认保存于history文件夹"))
91
  with gr.Row():
92
  with gr.Column():
93
- downloadFile = gr.File(interactive=True)
94
 
95
  with gr.Column(elem_id="chuanhu-menu-footer"):
96
  with gr.Row(elem_id="chuanhu-func-nav"):
@@ -179,9 +179,9 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
179
  with gr.Column(scale=6):
180
  templateFileSelectDropdown = gr.Dropdown(
181
  label=i18n("选择Prompt模板集合文件"),
182
- choices=get_template_names(plain=True),
183
  multiselect=False,
184
- value=get_template_names(plain=True)[0],
185
  container=False,
186
  )
187
  with gr.Column(scale=1):
@@ -191,7 +191,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
191
  templateSelectDropdown = gr.Dropdown(
192
  label=i18n("从Prompt模板中加载"),
193
  choices=load_template(
194
- get_template_names(plain=True)[0], mode=1
195
  ),
196
  multiselect=False,
197
  container=False,
@@ -399,7 +399,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
399
  current_model = get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
400
  current_model.set_user_identifier(user_name)
401
  chatbot = gr.Chatbot.update(label=MODELS[DEFAULT_MODEL])
402
- return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), *current_model.auto_load(), get_history_names(False, user_name), chatbot
403
  demo.load(create_greeting, inputs=None, outputs=[user_info, user_name, current_model, like_dislike_area, systemPromptTxt, chatbot, historyFileSelectDropdown, chatbot], api_name="load")
404
  chatgpt_predict_args = dict(
405
  fn=predict,
@@ -446,7 +446,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
446
  )
447
 
448
  refresh_history_args = dict(
449
- fn=get_history_names, inputs=[gr.State(False), user_name], outputs=[historyFileSelectDropdown]
450
  )
451
 
452
 
@@ -524,8 +524,8 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
524
 
525
  # Template
526
  systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
527
- templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
528
- templateFileSelectDropdown.change(
529
  load_template,
530
  [templateFileSelectDropdown],
531
  [promptTemplates, templateSelectDropdown],
@@ -545,7 +545,7 @@ with gr.Blocks(theme=small_and_beautiful_theme) as demo:
545
  downloadFile,
546
  show_progress=True,
547
  )
548
- saveHistoryBtn.click(get_history_names, [gr.State(False), user_name], [historyFileSelectDropdown])
549
  exportMarkdownBtn.click(
550
  export_markdown,
551
  [current_model, saveFileName, chatbot, user_name],
 
31
 
32
  with gr.Blocks(theme=small_and_beautiful_theme) as demo:
33
  user_name = gr.State("")
34
+ promptTemplates = gr.State(load_template(get_template_names()[0], mode=2))
35
  user_question = gr.State("")
36
  assert type(my_api_key)==str
37
  user_api_key = gr.State(my_api_key)
 
64
  with gr.Column(scale=6, elem_id="history-select-wrap"):
65
  historyFileSelectDropdown = gr.Radio(
66
  label=i18n("从列表中加载对话"),
67
+ choices=get_history_names(),
68
  # multiselect=False,
69
  container=False,
70
  elem_id="history-select-dropdown"
 
90
  gr.Markdown(i18n("默认保存于history文件夹"))
91
  with gr.Row():
92
  with gr.Column():
93
+ downloadFile = gr.File(interactive=True, label="下载/上传历史记录")
94
 
95
  with gr.Column(elem_id="chuanhu-menu-footer"):
96
  with gr.Row(elem_id="chuanhu-func-nav"):
 
179
  with gr.Column(scale=6):
180
  templateFileSelectDropdown = gr.Dropdown(
181
  label=i18n("选择Prompt模板集合文件"),
182
+ choices=get_template_names(),
183
  multiselect=False,
184
+ value=get_template_names()[0],
185
  container=False,
186
  )
187
  with gr.Column(scale=1):
 
191
  templateSelectDropdown = gr.Dropdown(
192
  label=i18n("从Prompt模板中加载"),
193
  choices=load_template(
194
+ get_template_names()[0], mode=1
195
  ),
196
  multiselect=False,
197
  container=False,
 
399
  current_model = get_model(model_name = MODELS[DEFAULT_MODEL], access_key = my_api_key)[0]
400
  current_model.set_user_identifier(user_name)
401
  chatbot = gr.Chatbot.update(label=MODELS[DEFAULT_MODEL])
402
+ return user_info, user_name, current_model, toggle_like_btn_visibility(DEFAULT_MODEL), *current_model.auto_load(), get_history_dropdown(user_name), chatbot
403
  demo.load(create_greeting, inputs=None, outputs=[user_info, user_name, current_model, like_dislike_area, systemPromptTxt, chatbot, historyFileSelectDropdown, chatbot], api_name="load")
404
  chatgpt_predict_args = dict(
405
  fn=predict,
 
446
  )
447
 
448
  refresh_history_args = dict(
449
+ fn=get_history_dropdown, inputs=[user_name], outputs=[historyFileSelectDropdown]
450
  )
451
 
452
 
 
524
 
525
  # Template
526
  systemPromptTxt.change(set_system_prompt, [current_model, systemPromptTxt], None)
527
+ templateRefreshBtn.click(get_template_dropdown, None, [templateFileSelectDropdown])
528
+ templateFileSelectDropdown.input(
529
  load_template,
530
  [templateFileSelectDropdown],
531
  [promptTemplates, templateSelectDropdown],
 
545
  downloadFile,
546
  show_progress=True,
547
  )
548
+ saveHistoryBtn.click(get_history_dropdown, [user_name], [historyFileSelectDropdown])
549
  exportMarkdownBtn.click(
550
  export_markdown,
551
  [current_model, saveFileName, chatbot, user_name],
config_example.json CHANGED
@@ -11,6 +11,9 @@
11
  "midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
12
  "midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
13
  "midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
 
 
 
14
 
15
 
16
  //== Azure ==
 
11
  "midjourney_proxy_api_secret": "", // 你的 MidJourney Proxy API Secret,用于鉴权访问 api,可选
12
  "midjourney_discord_proxy_url": "", // 你的 MidJourney Discord Proxy URL,用于对生成对图进行反代,可选
13
  "midjourney_temp_folder": "./tmp", // 你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)
14
+ "spark_appid": "", // 你的 讯飞星火大模型 API AppID,用于讯飞星火大模型对话模型
15
+ "spark_api_key": "", // 你的 讯飞星火大模型 API Key,用于讯飞星火大模型对话模型
16
+ "spark_api_secret": "", // 你的 讯飞星火大模型 API Secret,用于讯飞星火大模型对话模型
17
 
18
 
19
  //== Azure ==
modules/config.py CHANGED
@@ -123,6 +123,13 @@ os.environ["MIDJOURNEY_DISCORD_PROXY_URL"] = midjourney_discord_proxy_url
123
  midjourney_temp_folder = config.get("midjourney_temp_folder", "")
124
  os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
125
 
 
 
 
 
 
 
 
126
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
127
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
128
 
 
123
  midjourney_temp_folder = config.get("midjourney_temp_folder", "")
124
  os.environ["MIDJOURNEY_TEMP_FOLDER"] = midjourney_temp_folder
125
 
126
+ spark_api_key = config.get("spark_api_key", "")
127
+ os.environ["SPARK_API_KEY"] = spark_api_key
128
+ spark_appid = config.get("spark_appid", "")
129
+ os.environ["SPARK_APPID"] = spark_appid
130
+ spark_api_secret = config.get("spark_api_secret", "")
131
+ os.environ["SPARK_API_SECRET"] = spark_api_secret
132
+
133
  load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
134
  "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
135
 
modules/index_func.py CHANGED
@@ -23,6 +23,7 @@ def get_documents(file_src):
23
  filename = os.path.basename(filepath)
24
  file_type = os.path.splitext(filename)[1]
25
  logging.info(f"loading file: {filename}")
 
26
  try:
27
  if file_type == ".pdf":
28
  logging.debug("Loading PDF...")
@@ -72,8 +73,9 @@ def get_documents(file_src):
72
  logging.error(f"Error loading file: {filename}")
73
  traceback.print_exc()
74
 
75
- texts = text_splitter.split_documents(texts)
76
- documents.extend(texts)
 
77
  logging.debug("Documents loaded.")
78
  return documents
79
 
 
23
  filename = os.path.basename(filepath)
24
  file_type = os.path.splitext(filename)[1]
25
  logging.info(f"loading file: {filename}")
26
+ texts = None
27
  try:
28
  if file_type == ".pdf":
29
  logging.debug("Loading PDF...")
 
73
  logging.error(f"Error loading file: {filename}")
74
  traceback.print_exc()
75
 
76
+ if texts is not None:
77
+ texts = text_splitter.split_documents(texts)
78
+ documents.extend(texts)
79
  logging.debug("Documents loaded.")
80
  return documents
81
 
modules/models/base_model.py CHANGED
@@ -142,6 +142,7 @@ class ModelType(Enum):
142
  GooglePaLM = 9
143
  LangchainChat = 10
144
  Midjourney = 11
 
145
 
146
  @classmethod
147
  def get_type(cls, model_name: str):
@@ -171,6 +172,8 @@ class ModelType(Enum):
171
  model_type = ModelType.Midjourney
172
  elif "azure" in model_name_lower or "api" in model_name_lower:
173
  model_type = ModelType.LangchainChat
 
 
174
  else:
175
  model_type = ModelType.Unknown
176
  return model_type
@@ -269,9 +272,12 @@ class BaseLLMModel:
269
  if display_append:
270
  display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
271
  partial_text = ""
 
272
  for partial_text in stream_iter:
 
 
273
  chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
274
- self.all_token_counts[-1] += 1
275
  status_text = self.token_message()
276
  yield get_return_value()
277
  if self.interrupted:
@@ -718,7 +724,7 @@ class BaseLLMModel:
718
  history_file_path = filename
719
  try:
720
  os.remove(history_file_path)
721
- return i18n("删除对话历史成功"), get_history_names(False, user_name), []
722
  except:
723
  logging.info(f"删除对话历史失败 {history_file_path}")
724
  return i18n("对话历史")+filename+i18n("已经被删除啦"), gr.update(), gr.update()
 
142
  GooglePaLM = 9
143
  LangchainChat = 10
144
  Midjourney = 11
145
+ Spark = 12
146
 
147
  @classmethod
148
  def get_type(cls, model_name: str):
 
172
  model_type = ModelType.Midjourney
173
  elif "azure" in model_name_lower or "api" in model_name_lower:
174
  model_type = ModelType.LangchainChat
175
+ elif "星火大模型" in model_name_lower:
176
+ model_type = ModelType.Spark
177
  else:
178
  model_type = ModelType.Unknown
179
  return model_type
 
272
  if display_append:
273
  display_append = '\n\n<hr class="append-display no-in-raw" />' + display_append
274
  partial_text = ""
275
+ token_increment = 1
276
  for partial_text in stream_iter:
277
+ if type(partial_text) == tuple:
278
+ partial_text, token_increment = partial_text
279
  chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
280
+ self.all_token_counts[-1] += token_increment
281
  status_text = self.token_message()
282
  yield get_return_value()
283
  if self.interrupted:
 
724
  history_file_path = filename
725
  try:
726
  os.remove(history_file_path)
727
+ return i18n("删除对话历史成功"), get_history_dropdown(user_name), []
728
  except:
729
  logging.info(f"删除对话历史失败 {history_file_path}")
730
  return i18n("对话历史")+filename+i18n("已经被删除啦"), gr.update(), gr.update()
modules/models/models.py CHANGED
@@ -580,8 +580,8 @@ def get_model(
580
  logging.info(msg)
581
  lora_selector_visibility = True
582
  if os.path.isdir("lora"):
583
- lora_choices = get_file_names(
584
- "lora", plain=True, filetypes=[""])
585
  lora_choices = ["No LoRA"] + lora_choices
586
  elif model_type == ModelType.LLaMA and lora_model_path != "":
587
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
@@ -625,6 +625,9 @@ def get_model(
625
  from .midjourney import Midjourney_Client
626
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
627
  model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
 
 
 
628
  elif model_type == ModelType.Unknown:
629
  raise ValueError(f"未知模型: {model_name}")
630
  logging.info(msg)
 
580
  logging.info(msg)
581
  lora_selector_visibility = True
582
  if os.path.isdir("lora"):
583
+ lora_choices = get_file_names_dropdown_by_pinyin(
584
+ "lora", filetypes=[""])
585
  lora_choices = ["No LoRA"] + lora_choices
586
  elif model_type == ModelType.LLaMA and lora_model_path != "":
587
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
 
625
  from .midjourney import Midjourney_Client
626
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
627
  model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
628
+ elif model_type == ModelType.Spark:
629
+ from .spark import Spark_Client
630
+ model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv("SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
631
  elif model_type == ModelType.Unknown:
632
  raise ValueError(f"未知模型: {model_name}")
633
  logging.info(msg)
modules/models/spark.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import _thread as thread
2
+ import base64
3
+ import datetime
4
+ import hashlib
5
+ import hmac
6
+ import json
7
+ from collections import deque
8
+ from urllib.parse import urlparse
9
+ import ssl
10
+ from datetime import datetime
11
+ from time import mktime
12
+ from urllib.parse import urlencode
13
+ from wsgiref.handlers import format_date_time
14
+ from threading import Condition
15
+ import websocket
16
+ import logging
17
+
18
+ from .base_model import BaseLLMModel, CallbackToIterator
19
+
20
+
21
+ class Ws_Param(object):
22
+ # 来自官方 Demo
23
+ # 初始化
24
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
25
+ self.APPID = APPID
26
+ self.APIKey = APIKey
27
+ self.APISecret = APISecret
28
+ self.host = urlparse(Spark_url).netloc
29
+ self.path = urlparse(Spark_url).path
30
+ self.Spark_url = Spark_url
31
+
32
+ # 生成url
33
+ def create_url(self):
34
+ # 生成RFC1123格式的时间戳
35
+ now = datetime.now()
36
+ date = format_date_time(mktime(now.timetuple()))
37
+
38
+ # 拼接字符串
39
+ signature_origin = "host: " + self.host + "\n"
40
+ signature_origin += "date: " + date + "\n"
41
+ signature_origin += "GET " + self.path + " HTTP/1.1"
42
+
43
+ # 进行hmac-sha256进行加密
44
+ signature_sha = hmac.new(
45
+ self.APISecret.encode("utf-8"),
46
+ signature_origin.encode("utf-8"),
47
+ digestmod=hashlib.sha256,
48
+ ).digest()
49
+
50
+ signature_sha_base64 = base64.b64encode(
51
+ signature_sha).decode(encoding="utf-8")
52
+
53
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
54
+
55
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
56
+ encoding="utf-8"
57
+ )
58
+
59
+ # 将请求的鉴权参数组合为字典
60
+ v = {"authorization": authorization, "date": date, "host": self.host}
61
+ # 拼接鉴权参数,生成url
62
+ url = self.Spark_url + "?" + urlencode(v)
63
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
64
+ return url
65
+
66
+
67
+ class Spark_Client(BaseLLMModel):
68
+ def __init__(self, model_name, appid, api_key, api_secret, user_name="") -> None:
69
+ super().__init__(model_name=model_name, user=user_name)
70
+ self.api_key = api_key
71
+ self.appid = appid
72
+ self.api_secret = api_secret
73
+ if None in [self.api_key, self.appid, self.api_secret]:
74
+ raise Exception("请在配置文件或者环境变量中设置讯飞的API Key、APP ID和API Secret")
75
+ if "2.0" in self.model_name:
76
+ self.spark_url = "wss://spark-api.xf-yun.com/v2.1/chat"
77
+ self.domain = "generalv2"
78
+ else:
79
+ self.spark_url = "wss://spark-api.xf-yun.com/v1.1/chat"
80
+ self.domain = "general"
81
+
82
+ # 收到websocket错误的处理
83
+ def on_error(self, ws, error):
84
+ ws.iterator.callback("出现了错误:" + error)
85
+
86
+ # 收到websocket关闭的处理
87
+ def on_close(self, ws, one, two):
88
+ pass
89
+
90
+ # 收到websocket连接建立的处理
91
+ def on_open(self, ws):
92
+ thread.start_new_thread(self.run, (ws,))
93
+
94
+ def run(self, ws, *args):
95
+ data = json.dumps(
96
+ self.gen_params()
97
+ )
98
+ ws.send(data)
99
+
100
+ # 收到websocket消息的处理
101
+ def on_message(self, ws, message):
102
+ ws.iterator.callback(message)
103
+
104
+ def gen_params(self):
105
+ """
106
+ 通过appid和用户的提问来生成请参数
107
+ """
108
+ data = {
109
+ "header": {"app_id": self.appid, "uid": "1234"},
110
+ "parameter": {
111
+ "chat": {
112
+ "domain": self.domain,
113
+ "random_threshold": self.temperature,
114
+ "max_tokens": 4096,
115
+ "auditing": "default",
116
+ }
117
+ },
118
+ "payload": {"message": {"text": self.history}},
119
+ }
120
+ return data
121
+
122
+ def get_answer_stream_iter(self):
123
+ wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.spark_url)
124
+ websocket.enableTrace(False)
125
+ wsUrl = wsParam.create_url()
126
+ ws = websocket.WebSocketApp(
127
+ wsUrl,
128
+ on_message=self.on_message,
129
+ on_error=self.on_error,
130
+ on_close=self.on_close,
131
+ on_open=self.on_open,
132
+ )
133
+ ws.appid = self.appid
134
+ ws.domain = self.domain
135
+
136
+ # Initialize the CallbackToIterator
137
+ ws.iterator = CallbackToIterator()
138
+
139
+ # Start the WebSocket connection in a separate thread
140
+ thread.start_new_thread(
141
+ ws.run_forever, (), {"sslopt": {"cert_reqs": ssl.CERT_NONE}}
142
+ )
143
+
144
+ # Iterate over the CallbackToIterator instance
145
+ answer = ""
146
+ total_tokens = 0
147
+ for message in ws.iterator:
148
+ data = json.loads(message)
149
+ code = data["header"]["code"]
150
+ if code != 0:
151
+ ws.close()
152
+ raise Exception(f"请求错误: {code}, {data}")
153
+ else:
154
+ choices = data["payload"]["choices"]
155
+ status = choices["status"]
156
+ content = choices["text"][0]["content"]
157
+ if "usage" in data["payload"]:
158
+ total_tokens = data["payload"]["usage"]["text"]["total_tokens"]
159
+ answer += content
160
+ if status == 2:
161
+ ws.iterator.finish() # Finish the iterator when the status is 2
162
+ ws.close()
163
+ yield answer, total_tokens
modules/presets.py CHANGED
@@ -69,7 +69,9 @@ ONLINE_MODELS = [
69
  "yuanai-1.0-rhythm_poems",
70
  "minimax-abab4-chat",
71
  "minimax-abab5-chat",
72
- "midjourney"
 
 
73
  ]
74
 
75
  LOCAL_MODELS = [
 
69
  "yuanai-1.0-rhythm_poems",
70
  "minimax-abab4-chat",
71
  "minimax-abab5-chat",
72
+ "midjourney",
73
+ "讯飞星火大模型V2.0",
74
+ "讯飞星火大模型V1.5"
75
  ]
76
 
77
  LOCAL_MODELS = [
modules/utils.py CHANGED
@@ -354,31 +354,50 @@ def save_file(filename, system, history, chatbot, user_name):
354
  def sorted_by_pinyin(list):
355
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
356
 
 
 
357
 
358
- def get_file_names(dir, plain=False, filetypes=[".json"]):
359
- logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
360
  files = []
361
  try:
362
  for type in filetypes:
363
  files += [f for f in os.listdir(dir) if f.endswith(type)]
364
  except FileNotFoundError:
365
- files = []
366
- files = sorted_by_pinyin(files)
367
- if files == []:
368
  files = [""]
369
  logging.debug(f"files are:{files}")
370
- if plain:
371
- return files
372
- else:
373
- return gr.Dropdown.update(choices=files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
 
376
- def get_history_names(plain=False, user_name=""):
377
  logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
378
  if user_name == "" and hide_history_when_not_logged_in:
379
  return ""
380
  else:
381
- return get_file_names(os.path.join(HISTORY_DIR, user_name), plain)
 
 
 
 
 
382
 
383
 
384
  def load_template(filename, mode=0):
@@ -406,9 +425,14 @@ def load_template(filename, mode=0):
406
  )
407
 
408
 
409
- def get_template_names(plain=False):
410
  logging.debug("获取模板文件名列表")
411
- return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
 
 
 
 
 
412
 
413
 
414
  def get_template_content(templates, selection, original_system_prompt):
 
354
  def sorted_by_pinyin(list):
355
  return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
356
 
357
+ def sorted_by_last_modified_time(list, dir):
358
+ return sorted(list, key=lambda char: os.path.getmtime(os.path.join(dir, char)), reverse=True)
359
 
360
+ def get_file_names_by_type(dir, filetypes=[".json"]):
361
+ logging.debug(f"获取文件名列表,目录为{dir},文件类型为{filetypes}")
362
  files = []
363
  try:
364
  for type in filetypes:
365
  files += [f for f in os.listdir(dir) if f.endswith(type)]
366
  except FileNotFoundError:
 
 
 
367
  files = [""]
368
  logging.debug(f"files are:{files}")
369
+ return files
370
+
371
+ def get_file_names_by_pinyin(dir, filetypes=[".json"]):
372
+ files = get_file_names_by_type(dir, filetypes)
373
+ if files != [""]:
374
+ files = sorted_by_pinyin(files)
375
+ logging.debug(f"files are:{files}")
376
+ return files
377
+
378
+ def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]):
379
+ files = get_file_names_by_pinyin(dir, filetypes)
380
+ return gr.Dropdown.update(choices=files)
381
+
382
+ def get_file_names_by_last_modified_time(dir, filetypes=[".json"]):
383
+ files = get_file_names_by_type(dir, filetypes)
384
+ if files != [""]:
385
+ files = sorted_by_last_modified_time(files, dir)
386
+ logging.debug(f"files are:{files}")
387
+ return files
388
 
389
 
390
+ def get_history_names(user_name=""):
391
  logging.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
392
  if user_name == "" and hide_history_when_not_logged_in:
393
  return ""
394
  else:
395
+ history_files = get_file_names_by_last_modified_time(os.path.join(HISTORY_DIR, user_name))
396
+ return history_files
397
+
398
+ def get_history_dropdown(user_name=""):
399
+ history_names = get_history_names(user_name)
400
+ return gr.Dropdown.update(choices=history_names)
401
 
402
 
403
  def load_template(filename, mode=0):
 
425
  )
426
 
427
 
428
+ def get_template_names():
429
  logging.debug("获取模板文件名列表")
430
+ return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"])
431
+
432
+ def get_template_dropdown():
433
+ logging.debug("获取模板下拉菜单")
434
+ template_names = get_template_names()
435
+ return gr.Dropdown.update(choices=template_names)
436
 
437
 
438
  def get_template_content(templates, selection, original_system_prompt):
requirements.txt CHANGED
@@ -26,3 +26,5 @@ unstructured
26
  google-api-python-client
27
  tabulate
28
  ujson
 
 
 
26
  google-api-python-client
27
  tabulate
28
  ujson
29
+ python-docx
30
+ websocket_client