Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
•
8728d12
1
Parent(s):
3675c9f
bugfix: 加入LLaMA.cpp
Browse files- modules/models/OpenAI.py +14 -13
- modules/models/XMChat.py +2 -2
- modules/models/midjourney.py +3 -4
- modules/models/models.py +2 -2
modules/models/OpenAI.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2 |
|
3 |
import json
|
4 |
import logging
|
|
|
5 |
|
6 |
import colorama
|
7 |
import requests
|
@@ -85,11 +86,11 @@ class OpenAIClient(BaseLLMModel):
|
|
85 |
|
86 |
# return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
|
87 |
return get_html("billing_info.html").format(
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
except requests.exceptions.ConnectTimeout:
|
94 |
status_text = (
|
95 |
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
@@ -161,6 +162,7 @@ class OpenAIClient(BaseLLMModel):
|
|
161 |
timeout=timeout,
|
162 |
)
|
163 |
except:
|
|
|
164 |
return None
|
165 |
return response
|
166 |
|
@@ -170,6 +172,7 @@ class OpenAIClient(BaseLLMModel):
|
|
170 |
"Authorization": f"Bearer {sensitive_id}",
|
171 |
}
|
172 |
|
|
|
173 |
def _get_billing_data(self, billing_url):
|
174 |
with retrieve_proxy():
|
175 |
response = requests.get(
|
@@ -240,6 +243,7 @@ class OpenAIClient(BaseLLMModel):
|
|
240 |
|
241 |
return response
|
242 |
|
|
|
243 |
def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
|
244 |
if len(self.history) == 2 and not single_turn_checkbox:
|
245 |
user_question = self.history[0]["content"]
|
@@ -247,22 +251,19 @@ class OpenAIClient(BaseLLMModel):
|
|
247 |
ai_answer = self.history[1]["content"]
|
248 |
try:
|
249 |
history = [
|
250 |
-
{"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
|
251 |
-
{"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
|
252 |
]
|
253 |
-
response = self._single_query_at_once(
|
254 |
-
history, temperature=0.0)
|
255 |
response = json.loads(response.text)
|
256 |
content = response["choices"][0]["message"]["content"]
|
257 |
filename = replace_special_symbols(content) + ".json"
|
258 |
except Exception as e:
|
259 |
logging.info(f"自动命名失败。{e}")
|
260 |
-
filename = replace_special_symbols(user_question)[
|
261 |
-
:16] + ".json"
|
262 |
return self.rename_chat_history(filename, chatbot, user_name)
|
263 |
elif name_chat_method == i18n("第一条提问"):
|
264 |
-
filename = replace_special_symbols(user_question)[
|
265 |
-
:16] + ".json"
|
266 |
return self.rename_chat_history(filename, chatbot, user_name)
|
267 |
else:
|
268 |
return gr.update()
|
|
|
2 |
|
3 |
import json
|
4 |
import logging
|
5 |
+
import traceback
|
6 |
|
7 |
import colorama
|
8 |
import requests
|
|
|
86 |
|
87 |
# return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
|
88 |
return get_html("billing_info.html").format(
|
89 |
+
label = i18n("本月使用金额"),
|
90 |
+
usage_percent = usage_percent,
|
91 |
+
rounded_usage = rounded_usage,
|
92 |
+
usage_limit = usage_limit
|
93 |
+
)
|
94 |
except requests.exceptions.ConnectTimeout:
|
95 |
status_text = (
|
96 |
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
|
|
|
162 |
timeout=timeout,
|
163 |
)
|
164 |
except:
|
165 |
+
traceback.print_exc()
|
166 |
return None
|
167 |
return response
|
168 |
|
|
|
172 |
"Authorization": f"Bearer {sensitive_id}",
|
173 |
}
|
174 |
|
175 |
+
|
176 |
def _get_billing_data(self, billing_url):
|
177 |
with retrieve_proxy():
|
178 |
response = requests.get(
|
|
|
243 |
|
244 |
return response
|
245 |
|
246 |
+
|
247 |
def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
|
248 |
if len(self.history) == 2 and not single_turn_checkbox:
|
249 |
user_question = self.history[0]["content"]
|
|
|
251 |
ai_answer = self.history[1]["content"]
|
252 |
try:
|
253 |
history = [
|
254 |
+
{ "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
|
255 |
+
{ "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
|
256 |
]
|
257 |
+
response = self._single_query_at_once(history, temperature=0.0)
|
|
|
258 |
response = json.loads(response.text)
|
259 |
content = response["choices"][0]["message"]["content"]
|
260 |
filename = replace_special_symbols(content) + ".json"
|
261 |
except Exception as e:
|
262 |
logging.info(f"自动命名失败。{e}")
|
263 |
+
filename = replace_special_symbols(user_question)[:16] + ".json"
|
|
|
264 |
return self.rename_chat_history(filename, chatbot, user_name)
|
265 |
elif name_chat_method == i18n("第一条提问"):
|
266 |
+
filename = replace_special_symbols(user_question)[:16] + ".json"
|
|
|
267 |
return self.rename_chat_history(filename, chatbot, user_name)
|
268 |
else:
|
269 |
return gr.update()
|
modules/models/XMChat.py
CHANGED
@@ -16,7 +16,7 @@ from ..utils import *
|
|
16 |
from .base_model import BaseLLMModel
|
17 |
|
18 |
|
19 |
-
class
|
20 |
def __init__(self, api_key, user_name=""):
|
21 |
super().__init__(model_name="xmchat", user=user_name)
|
22 |
self.api_key = api_key
|
@@ -31,7 +31,7 @@ class XMChatClient(BaseLLMModel):
|
|
31 |
def reset(self):
|
32 |
self.session_id = str(uuid.uuid4())
|
33 |
self.last_conv_id = None
|
34 |
-
return
|
35 |
|
36 |
def image_to_base64(self, image_path):
|
37 |
# 打开并加载图片
|
|
|
16 |
from .base_model import BaseLLMModel
|
17 |
|
18 |
|
19 |
+
class XMChat(BaseLLMModel):
|
20 |
def __init__(self, api_key, user_name=""):
|
21 |
super().__init__(model_name="xmchat", user=user_name)
|
22 |
self.api_key = api_key
|
|
|
31 |
def reset(self):
|
32 |
self.session_id = str(uuid.uuid4())
|
33 |
self.last_conv_id = None
|
34 |
+
return super().reset()
|
35 |
|
36 |
def image_to_base64(self, image_path):
|
37 |
# 打开并加载图片
|
modules/models/midjourney.py
CHANGED
@@ -2,11 +2,10 @@ import base64
|
|
2 |
import io
|
3 |
import json
|
4 |
import logging
|
|
|
5 |
import pathlib
|
6 |
-
import time
|
7 |
import tempfile
|
8 |
-
import
|
9 |
-
|
10 |
from datetime import datetime
|
11 |
|
12 |
import requests
|
@@ -14,7 +13,7 @@ import tiktoken
|
|
14 |
from PIL import Image
|
15 |
|
16 |
from modules.config import retrieve_proxy
|
17 |
-
from modules.models.
|
18 |
|
19 |
mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
|
20 |
mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
|
|
|
2 |
import io
|
3 |
import json
|
4 |
import logging
|
5 |
+
import os
|
6 |
import pathlib
|
|
|
7 |
import tempfile
|
8 |
+
import time
|
|
|
9 |
from datetime import datetime
|
10 |
|
11 |
import requests
|
|
|
13 |
from PIL import Image
|
14 |
|
15 |
from modules.config import retrieve_proxy
|
16 |
+
from modules.models.XMChat import XMChat
|
17 |
|
18 |
mj_proxy_api_base = os.getenv("MIDJOURNEY_PROXY_API_BASE")
|
19 |
mj_discord_proxy_url = os.getenv("MIDJOURNEY_DISCORD_PROXY_URL")
|
modules/models/models.py
CHANGED
@@ -69,10 +69,10 @@ def get_model(
|
|
69 |
model = LLaMA_Client(
|
70 |
model_name, lora_model_path, user_name=user_name)
|
71 |
elif model_type == ModelType.XMChat:
|
72 |
-
from .XMChat import
|
73 |
if os.environ.get("XMCHAT_API_KEY") != "":
|
74 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
75 |
-
model =
|
76 |
elif model_type == ModelType.StableLM:
|
77 |
from .StableLM import StableLM_Client
|
78 |
model = StableLM_Client(model_name, user_name=user_name)
|
|
|
69 |
model = LLaMA_Client(
|
70 |
model_name, lora_model_path, user_name=user_name)
|
71 |
elif model_type == ModelType.XMChat:
|
72 |
+
from .XMChat import XMChat
|
73 |
if os.environ.get("XMCHAT_API_KEY") != "":
|
74 |
access_key = os.environ.get("XMCHAT_API_KEY")
|
75 |
+
model = XMChat(api_key=access_key, user_name=user_name)
|
76 |
elif model_type == ModelType.StableLM:
|
77 |
from .StableLM import StableLM_Client
|
78 |
model = StableLM_Client(model_name, user_name=user_name)
|