Tuchuanhuhuhu commited on
Commit
8728d12
1 Parent(s): 3675c9f

bugfix: 加入LLaMA.cpp

Browse files
modules/models/OpenAI.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import json
4
  import logging
 
5
 
6
  import colorama
7
  import requests
@@ -85,11 +86,11 @@ class OpenAIClient(BaseLLMModel):
85
 
86
  # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
87
  return get_html("billing_info.html").format(
88
- label=i18n("本月使用金额"),
89
- usage_percent=usage_percent,
90
- rounded_usage=rounded_usage,
91
- usage_limit=usage_limit
92
- )
93
  except requests.exceptions.ConnectTimeout:
94
  status_text = (
95
  STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
@@ -161,6 +162,7 @@ class OpenAIClient(BaseLLMModel):
161
  timeout=timeout,
162
  )
163
  except:
 
164
  return None
165
  return response
166
 
@@ -170,6 +172,7 @@ class OpenAIClient(BaseLLMModel):
170
  "Authorization": f"Bearer {sensitive_id}",
171
  }
172
 
 
173
  def _get_billing_data(self, billing_url):
174
  with retrieve_proxy():
175
  response = requests.get(
@@ -240,6 +243,7 @@ class OpenAIClient(BaseLLMModel):
240
 
241
  return response
242
 
 
243
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
244
  if len(self.history) == 2 and not single_turn_checkbox:
245
  user_question = self.history[0]["content"]
@@ -247,22 +251,19 @@ class OpenAIClient(BaseLLMModel):
247
  ai_answer = self.history[1]["content"]
248
  try:
249
  history = [
250
- {"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
251
- {"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
252
  ]
253
- response = self._single_query_at_once(
254
- history, temperature=0.0)
255
  response = json.loads(response.text)
256
  content = response["choices"][0]["message"]["content"]
257
  filename = replace_special_symbols(content) + ".json"
258
  except Exception as e:
259
  logging.info(f"自动命名失败。{e}")
260
- filename = replace_special_symbols(user_question)[
261
- :16] + ".json"
262
  return self.rename_chat_history(filename, chatbot, user_name)
263
  elif name_chat_method == i18n("第一条提问"):
264
- filename = replace_special_symbols(user_question)[
265
- :16] + ".json"
266
  return self.rename_chat_history(filename, chatbot, user_name)
267
  else:
268
  return gr.update()
 
2
 
3
  import json
4
  import logging
5
+ import traceback
6
 
7
  import colorama
8
  import requests
 
86
 
87
  # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
88
  return get_html("billing_info.html").format(
89
+ label = i18n("本月使用金额"),
90
+ usage_percent = usage_percent,
91
+ rounded_usage = rounded_usage,
92
+ usage_limit = usage_limit
93
+ )
94
  except requests.exceptions.ConnectTimeout:
95
  status_text = (
96
  STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
 
162
  timeout=timeout,
163
  )
164
  except:
165
+ traceback.print_exc()
166
  return None
167
  return response
168
 
 
172
  "Authorization": f"Bearer {sensitive_id}",
173
  }
174
 
175
+
176
  def _get_billing_data(self, billing_url):
177
  with retrieve_proxy():
178
  response = requests.get(
 
243
 
244
  return response
245
 
246
+
247
  def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
248
  if len(self.history) == 2 and not single_turn_checkbox:
249
  user_question = self.history[0]["content"]
 
251
  ai_answer = self.history[1]["content"]
252
  try:
253
  history = [
254
+ { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
255
+ { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
256
  ]
257
+ response = self._single_query_at_once(history, temperature=0.0)
 
258
  response = json.loads(response.text)
259
  content = response["choices"][0]["message"]["content"]
260
  filename = replace_special_symbols(content) + ".json"
261
  except Exception as e:
262
  logging.info(f"自动命名失败。{e}")
263
+ filename = replace_special_symbols(user_question)[:16] + ".json"
 
264
  return self.rename_chat_history(filename, chatbot, user_name)
265
  elif name_chat_method == i18n("第一条提问"):
266
+ filename = replace_special_symbols(user_question)[:16] + ".json"
 
267
  return self.rename_chat_history(filename, chatbot, user_name)
268
  else:
269
  return gr.update()
modules/models/XMChat.py CHANGED
@@ -16,7 +16,7 @@ from ..utils import *
16
  from .base_model import BaseLLMModel
17
 
18
 
19
- class XMChatClient(BaseLLMModel):
20
  def __init__(self, api_key, user_name=""):
21
  super().__init__(model_name="xmchat", user=user_name)
22
  self.api_key = api_key
@@ -31,7 +31,7 @@ class XMChatClient(BaseLLMModel):
31
  def reset(self):
32
  self.session_id = str(uuid.uuid4())
33
  self.last_conv_id = None
34
- return [], "已重置"
35
 
36
  def image_to_base64(self, image_path):
37
  # 打开并加载图片
 
16
  from .base_model import BaseLLMModel
17
 
18
 
19
+ class XMChat(BaseLLMModel):
20
  def __init__(self, api_key, user_name=""):
21
  super().__init__(model_name="xmchat", user=user_name)
22
  self.api_key = api_key
 
31
  def reset(self):
32
  self.session_id = str(uuid.uuid4())
33
  self.last_conv_id = None
34
+ return super().reset()
35
 
36
  def image_to_base64(self, image_path):
37
  # 打开并加载图片
modules/models/midjourney.py CHANGED
@@ -2,11 +2,10 @@ import base64
2
  import io
3
  import json
4
  import logging
 
5
  import pathlib
6
- import time
7
  import tempfile
8
- import os
9
-
10
  from datetime import datetime
11
 
12
  import requests
@@ -14,7 +13,7 @@ import tiktoken
14
  from PIL import Image
15
 
16
  from modules.config import retrieve_proxy
17
- from modules.models.models import XMChat
18
 
19
  mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
20
  mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
 
2
  import io
3
  import json
4
  import logging
5
+ import os
6
  import pathlib
 
7
  import tempfile
8
+ import time
 
9
  from datetime import datetime
10
 
11
  import requests
 
13
  from PIL import Image
14
 
15
  from modules.config import retrieve_proxy
16
+ from modules.models.XMChat import XMChat
17
 
18
  mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
19
  mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
modules/models/models.py CHANGED
@@ -69,10 +69,10 @@ def get_model(
69
  model = LLaMA_Client(
70
  model_name, lora_model_path, user_name=user_name)
71
  elif model_type == ModelType.XMChat:
72
- from .XMChat import XMChatClient
73
  if os.environ.get("XMCHAT_API_KEY") != "":
74
  access_key = os.environ.get("XMCHAT_API_KEY")
75
- model = XMChatClient(api_key=access_key, user_name=user_name)
76
  elif model_type == ModelType.StableLM:
77
  from .StableLM import StableLM_Client
78
  model = StableLM_Client(model_name, user_name=user_name)
 
69
  model = LLaMA_Client(
70
  model_name, lora_model_path, user_name=user_name)
71
  elif model_type == ModelType.XMChat:
72
+ from .XMChat import XMChat
73
  if os.environ.get("XMCHAT_API_KEY") != "":
74
  access_key = os.environ.get("XMCHAT_API_KEY")
75
+ model = XMChat(api_key=access_key, user_name=user_name)
76
  elif model_type == ModelType.StableLM:
77
  from .StableLM import StableLM_Client
78
  model = StableLM_Client(model_name, user_name=user_name)