Spaces:
Sleeping
Sleeping
Feliciano Long
committed on
Commit
•
cef64b2
1
Parent(s):
0aaa1a4
feat: Add Minimax model (#774)
Browse files* add minimax model and implement both get answer at once/stream, also the role play setting and token count
* add configurable minimax api_key and group id, add explanation to config example
- config_example.json +3 -0
- modules/config.py +5 -0
- modules/models/base_model.py +3 -0
- modules/models/minimax.py +161 -0
- modules/models/models.py +5 -0
- modules/presets.py +2 -0
config_example.json
CHANGED
@@ -5,6 +5,9 @@
|
|
5 |
"usage_limit": 120, // API Key的当月限额,单位:美元
|
6 |
// 你的xmchat API Key,与OpenAI API Key不同
|
7 |
"xmchat_api_key": "",
|
|
|
|
|
|
|
8 |
"language": "auto",
|
9 |
// 如果使用代理,请取消注释下面的两行,并替换代理URL
|
10 |
// "https_proxy": "http://127.0.0.1:1079",
|
|
|
5 |
"usage_limit": 120, // API Key的当月限额,单位:美元
|
6 |
// 你的xmchat API Key,与OpenAI API Key不同
|
7 |
"xmchat_api_key": "",
|
8 |
+
// MiniMax的APIKey(见账户管理页面 https://api.minimax.chat/basic-information)和Group ID,用于MiniMax对话模型
|
9 |
+
"minimax_api_key": "",
|
10 |
+
"minimax_group_id": "",
|
11 |
"language": "auto",
|
12 |
// 如果使用代理,请取消注释下面的两行,并替换代理URL
|
13 |
// "https_proxy": "http://127.0.0.1:1079",
|
modules/config.py
CHANGED
@@ -76,6 +76,11 @@ my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
|
|
76 |
xmchat_api_key = config.get("xmchat_api_key", "")
|
77 |
os.environ["XMCHAT_API_KEY"] = xmchat_api_key
|
78 |
|
|
|
|
|
|
|
|
|
|
|
79 |
render_latex = config.get("render_latex", True)
|
80 |
|
81 |
if render_latex:
|
|
|
76 |
xmchat_api_key = config.get("xmchat_api_key", "")
|
77 |
os.environ["XMCHAT_API_KEY"] = xmchat_api_key
|
78 |
|
79 |
+
minimax_api_key = config.get("minimax_api_key", "")
|
80 |
+
os.environ["MINIMAX_API_KEY"] = minimax_api_key
|
81 |
+
minimax_group_id = config.get("minimax_group_id", "")
|
82 |
+
os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
|
83 |
+
|
84 |
render_latex = config.get("render_latex", True)
|
85 |
|
86 |
if render_latex:
|
modules/models/base_model.py
CHANGED
@@ -34,6 +34,7 @@ class ModelType(Enum):
|
|
34 |
StableLM = 4
|
35 |
MOSS = 5
|
36 |
YuanAI = 6
|
|
|
37 |
|
38 |
@classmethod
|
39 |
def get_type(cls, model_name: str):
|
@@ -53,6 +54,8 @@ class ModelType(Enum):
|
|
53 |
model_type = ModelType.MOSS
|
54 |
elif "yuanai" in model_name_lower:
|
55 |
model_type = ModelType.YuanAI
|
|
|
|
|
56 |
else:
|
57 |
model_type = ModelType.Unknown
|
58 |
return model_type
|
|
|
34 |
StableLM = 4
|
35 |
MOSS = 5
|
36 |
YuanAI = 6
|
37 |
+
Minimax = 7
|
38 |
|
39 |
@classmethod
|
40 |
def get_type(cls, model_name: str):
|
|
|
54 |
model_type = ModelType.MOSS
|
55 |
elif "yuanai" in model_name_lower:
|
56 |
model_type = ModelType.YuanAI
|
57 |
+
elif "minimax" in model_name_lower:
|
58 |
+
model_type = ModelType.Minimax
|
59 |
else:
|
60 |
model_type = ModelType.Unknown
|
61 |
return model_type
|
modules/models/minimax.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import colorama
|
5 |
+
import requests
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from modules.models.base_model import BaseLLMModel
|
9 |
+
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n
|
10 |
+
|
group_id = os.environ.get("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
    """
    MiniMax chat-completion client.

    Implements both one-shot (`get_answer_at_once`) and streaming
    (`get_answer_stream_iter`) completion against the MiniMax API.
    API reference: https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        # The Group ID is passed as a URL query parameter, not in the body.
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def _convert_temperature(self):
        """Map the base model's [0, 2] temperature onto MiniMax's (0, 1] range.

        MiniMax temperature is (0, 1] while the base model's is [0, 2];
        MiniMax 0.9 corresponds to base 1.0, hence the piecewise scaling.
        """
        if self.temperature <= 1:
            return self.temperature * 0.9
        return 0.9 + (self.temperature - 1) / 10

    def _base_request_body(self, messages):
        """Build the request-body fields shared by one-shot and streaming calls."""
        request_body = {
            # Model names are prefixed with "minimax-" in the UI model list.
            "model": self.model_name.replace('minimax-', ''),
            "temperature": self._convert_temperature(),
            # Do not mask personal information in replies.
            "skip_info_mask": True,
            'messages': messages,
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.top_p:
            request_body['top_p'] = self.top_p
        return request_body

    def get_answer_at_once(self):
        """Return (answer_text, total_token_count) for a single completion.

        Raises Exception with the API's status code/message when the call
        fails at the API level (previously this died with a KeyError).
        """
        # NOTE(review): only the latest user message is sent here, so the
        # non-streaming path is effectively single-turn while the streaming
        # path sends the full history. Preserved as-is — confirm intent.
        messages = [{"sender_type": "USER", "text": self.history[-1]['content']}]
        request_body = self._base_request_body(messages)
        if self.system_prompt:
            request_body['prompt'] = self.system_prompt
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token

        # Fix: the original call had no timeout and could hang indefinitely.
        response = requests.post(
            self.url, headers=self.headers, json=request_body, timeout=TIMEOUT_ALL
        )

        res = response.json()
        # Surface API-level failures instead of a bare KeyError on 'reply'.
        base_resp = res.get('base_resp', {})
        if base_resp.get('status_code', 0) != 0:
            raise Exception(
                f"{base_resp.get('status_code')} - {base_resp.get('status_msg')}"
            )
        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        """Yield the progressively-accumulated reply text for streaming UIs."""
        response = self._get_response(stream=True)
        if response is not None:
            partial_text = ""
            for delta in self._decode_chat_response(response):
                partial_text += delta
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

    def _get_response(self, stream=False):
        """POST a chat-completion request; return the Response, or None on
        a network-level failure (the caller maps None to a general error)."""
        logging.debug(colorama.Fore.YELLOW +
                      f"{self.history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        messages = []
        for msg in self.history:
            sender_type = "USER" if msg['role'] == 'user' else "BOT"
            messages.append({"sender_type": sender_type, "text": msg['content']})

        request_body = self._base_request_body(messages)
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            # A short first line of the form "user:bot" names the two roles.
            if lines[0].find(":") != -1 and len(lines[0]) < 20:
                request_body["role_meta"] = {
                    "user_name": lines[0].split(":")[0],
                    "bot_name": lines[0].split(":")[1]
                }
                # Fix: remove the role-meta line itself. The original called
                # lines.pop(), which dropped the LAST prompt line and left
                # the "user:bot" line inside the prompt text.
                lines.pop(0)
            request_body["prompt"] = "\n".join(lines)
        # MiniMax requires tokens_to_generate; default to 512 when unset.
        request_body['tokens_to_generate'] = self.max_generation_token or 512

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            # Standard SSE framing ("data: ..." lines) for easier parsing.
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL
        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except requests.exceptions.RequestException:
            # Fix: was a bare `except:` that also swallowed KeyboardInterrupt.
            logging.exception("MiniMax request failed")
            return None

        return response

    def _decode_chat_response(self, response):
        """Yield text deltas from the SSE stream; record token usage on the
        final frame.

        Raises Exception with the API's status message when the stream
        carries an error payload instead of completions.
        """
        error_msg = ""
        for raw_chunk in response.iter_lines():
            if not raw_chunk:
                continue
            chunk = raw_chunk.decode()
            chunk_length = len(chunk)
            # Fix: debug `print(chunk)` leftover replaced with logging.
            logging.debug(chunk)
            try:
                # SSE lines are prefixed with "data: " (6 characters).
                chunk = json.loads(chunk[6:])
            except json.JSONDecodeError:
                logging.error(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
                error_msg += chunk
                continue
            if chunk_length > 6 and "delta" in chunk["choices"][0]:
                if chunk["choices"][0].get("finish_reason") == "stop":
                    # The final frame carries cumulative usage; store only
                    # the tokens not already accounted for.
                    self.all_token_counts.append(
                        chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
                    break
                try:
                    yield chunk["choices"][0]["delta"]
                except Exception as e:
                    logging.error(f"Error: {e}")
                    continue
        if error_msg:
            try:
                error_msg = json.loads(error_msg)
                if 'base_resp' in error_msg:
                    status_code = error_msg['base_resp']['status_code']
                    status_msg = error_msg['base_resp']['status_msg']
                    raise Exception(f"{status_code} - {status_msg}")
            except json.JSONDecodeError:
                pass
            raise Exception(error_msg)
|
modules/models/models.py
CHANGED
@@ -602,6 +602,11 @@ def get_model(
|
|
602 |
elif model_type == ModelType.YuanAI:
|
603 |
from .inspurai import Yuan_Client
|
604 |
model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
|
|
|
|
|
|
|
|
|
|
|
605 |
elif model_type == ModelType.Unknown:
|
606 |
raise ValueError(f"未知模型: {model_name}")
|
607 |
logging.info(msg)
|
|
|
602 |
elif model_type == ModelType.YuanAI:
|
603 |
from .inspurai import Yuan_Client
|
604 |
model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
|
605 |
+
elif model_type == ModelType.Minimax:
|
606 |
+
from .minimax import MiniMax_Client
|
607 |
+
if os.environ.get("MINIMAX_API_KEY") != "":
|
608 |
+
access_key = os.environ.get("MINIMAX_API_KEY")
|
609 |
+
model = MiniMax_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
|
610 |
elif model_type == ModelType.Unknown:
|
611 |
raise ValueError(f"未知模型: {model_name}")
|
612 |
logging.info(msg)
|
modules/presets.py
CHANGED
@@ -72,6 +72,8 @@ ONLINE_MODELS = [
|
|
72 |
"yuanai-1.0-translate",
|
73 |
"yuanai-1.0-dialog",
|
74 |
"yuanai-1.0-rhythm_poems",
|
|
|
|
|
75 |
]
|
76 |
|
77 |
LOCAL_MODELS = [
|
|
|
72 |
"yuanai-1.0-translate",
|
73 |
"yuanai-1.0-dialog",
|
74 |
"yuanai-1.0-rhythm_poems",
|
75 |
+
"minimax-abab4-chat",
|
76 |
+
"minimax-abab5-chat",
|
77 |
]
|
78 |
|
79 |
LOCAL_MODELS = [
|