Tuchuanhuhuhu committed on
Commit 3675c9f
2 Parent(s): a6c25bd b346648

Merge branch 'llamacpp'

.gitignore CHANGED
@@ -141,6 +141,7 @@ api_key.txt
 config.json
 auth.json
 .models/
+models/*
 lora/
 .idea
 templates/*
modules/models/{azure.py → Azure.py} RENAMED
File without changes
modules/models/ChatGLM.py ADDED
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import platform
6
+
7
+ import colorama
8
+
9
+ from ..index_func import *
10
+ from ..presets import *
11
+ from ..utils import *
12
+ from .base_model import BaseLLMModel
13
+
14
+
15
+ class ChatGLM_Client(BaseLLMModel):
16
+ def __init__(self, model_name, user_name="") -> None:
17
+ super().__init__(model_name=model_name, user=user_name)
18
+ import torch
19
+ from transformers import AutoModel, AutoTokenizer
20
+ global CHATGLM_TOKENIZER, CHATGLM_MODEL
21
+ if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
22
+ system_name = platform.system()
23
+ model_path = None
24
+ if os.path.exists("models"):
25
+ model_dirs = os.listdir("models")
26
+ if model_name in model_dirs:
27
+ model_path = f"models/{model_name}"
28
+ if model_path is not None:
29
+ model_source = model_path
30
+ else:
31
+ model_source = f"THUDM/{model_name}"
32
+ CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
33
+ model_source, trust_remote_code=True
34
+ )
35
+ quantified = False
36
+ if "int4" in model_name:
37
+ quantified = True
38
+ model = AutoModel.from_pretrained(
39
+ model_source, trust_remote_code=True
40
+ )
41
+ if torch.cuda.is_available():
42
+ # run on CUDA
43
+ logging.info("CUDA is available, using CUDA")
44
+ model = model.half().cuda()
45
+ # mps加速还存在一些问题,暂时不使用
46
+ elif system_name == "Darwin" and model_path is not None and not quantified:
47
+ logging.info("Running on macOS, using MPS")
48
+ # running on macOS and model already downloaded
49
+ model = model.half().to("mps")
50
+ else:
51
+ logging.info("GPU is not available, using CPU")
52
+ model = model.float()
53
+ model = model.eval()
54
+ CHATGLM_MODEL = model
55
+
56
+ def _get_glm_style_input(self):
57
+ history = [x["content"] for x in self.history]
58
+ query = history.pop()
59
+ logging.debug(colorama.Fore.YELLOW +
60
+ f"{history}" + colorama.Fore.RESET)
61
+ assert (
62
+ len(history) % 2 == 0
63
+ ), f"History should be even length. current history is: {history}"
64
+ history = [[history[i], history[i + 1]]
65
+ for i in range(0, len(history), 2)]
66
+ return history, query
67
+
68
+ def get_answer_at_once(self):
69
+ history, query = self._get_glm_style_input()
70
+ response, _ = CHATGLM_MODEL.chat(
71
+ CHATGLM_TOKENIZER, query, history=history)
72
+ return response, len(response)
73
+
74
+ def get_answer_stream_iter(self):
75
+ history, query = self._get_glm_style_input()
76
+ for response, history in CHATGLM_MODEL.stream_chat(
77
+ CHATGLM_TOKENIZER,
78
+ query,
79
+ history,
80
+ max_length=self.token_upper_limit,
81
+ top_p=self.top_p,
82
+ temperature=self.temperature,
83
+ ):
84
+ yield response
modules/models/{Google_PaLM.py → GooglePaLM.py} RENAMED
@@ -1,6 +1,7 @@
 from .base_model import BaseLLMModel
 import google.generativeai as palm
 
+
 class Google_PaLM_Client(BaseLLMModel):
     def __init__(self, model_name, api_key, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
@@ -18,9 +19,11 @@ class Google_PaLM_Client(BaseLLMModel):
     def get_answer_at_once(self):
         palm.configure(api_key=self.api_key)
         messages = self._get_palm_style_input()
-        response = palm.chat(context=self.system_prompt, messages=messages, temperature=self.temperature, top_p=self.top_p)
+        response = palm.chat(context=self.system_prompt, messages=messages,
+                             temperature=self.temperature, top_p=self.top_p)
         if response.last is not None:
             return response.last, len(response.last)
         else:
-            reasons = '\n\n'.join(reason['reason'].name for reason in response.filters)
-            return "由于下面的原因,Google 拒绝返回 PaLM 的回答:\n\n" + reasons, 0
+            reasons = '\n\n'.join(
+                reason['reason'].name for reason in response.filters)
+            return "由于下面的原因,Google 拒绝返回 PaLM 的回答:\n\n" + reasons, 0
modules/models/LLaMA.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+
6
+ from huggingface_hub import hf_hub_download
7
+ from llama_cpp import Llama
8
+
9
+ from ..index_func import *
10
+ from ..presets import *
11
+ from ..utils import *
12
+ from .base_model import BaseLLMModel
13
+
14
+ SYS_PREFIX = "<<SYS>>\n"
15
+ SYS_POSTFIX = "\n<</SYS>>\n\n"
16
+ INST_PREFIX = "<s>[INST] "
17
+ INST_POSTFIX = " "
18
+ OUTPUT_PREFIX = "[/INST] "
19
+ OUTPUT_POSTFIX = "</s>"
20
+
21
+
22
+ def download(repo_id, filename, retry=10):
23
+ if os.path.exists("./models/downloaded_models.json"):
24
+ with open("./models/downloaded_models.json", "r") as f:
25
+ downloaded_models = json.load(f)
26
+ if repo_id in downloaded_models:
27
+ return downloaded_models[repo_id]["path"]
28
+ else:
29
+ downloaded_models = {}
30
+ while retry > 0:
31
+ try:
32
+ model_path = hf_hub_download(
33
+ repo_id=repo_id,
34
+ filename=filename,
35
+ cache_dir="models",
36
+ resume_download=True,
37
+ )
38
+ downloaded_models[repo_id] = {"path": model_path}
39
+ with open("./models/downloaded_models.json", "w") as f:
40
+ json.dump(downloaded_models, f)
41
+ break
42
+ except:
43
+ print("Error downloading model, retrying...")
44
+ retry -= 1
45
+ if retry == 0:
46
+ raise Exception("Error downloading model, please try again later.")
47
+ return model_path
48
+
49
+
50
+ class LLaMA_Client(BaseLLMModel):
51
+ def __init__(self, model_name, lora_path=None, user_name="") -> None:
52
+ super().__init__(model_name=model_name, user=user_name)
53
+
54
+ self.max_generation_token = 1000
55
+ if model_name in MODEL_METADATA:
56
+ path_to_model = download(
57
+ MODEL_METADATA[model_name]["repo_id"],
58
+ MODEL_METADATA[model_name]["filelist"][0],
59
+ )
60
+ else:
61
+ dir_to_model = os.path.join("models", model_name)
62
+ # look for any .gguf file in the dir_to_model directory and its subdirectories
63
+ path_to_model = None
64
+ for root, dirs, files in os.walk(dir_to_model):
65
+ for file in files:
66
+ if file.endswith(".gguf"):
67
+ path_to_model = os.path.join(root, file)
68
+ break
69
+ if path_to_model is not None:
70
+ break
71
+ self.system_prompt = ""
72
+
73
+ if lora_path is not None:
74
+ lora_path = os.path.join("lora", lora_path)
75
+ self.model = Llama(model_path=path_to_model, lora_path=lora_path)
76
+ else:
77
+ self.model = Llama(model_path=path_to_model)
78
+
79
+ def _get_llama_style_input(self):
80
+ context = []
81
+ for conv in self.history:
82
+ if conv["role"] == "system":
83
+ context.append(SYS_PREFIX + conv["content"] + SYS_POSTFIX)
84
+ elif conv["role"] == "user":
85
+ context.append(
86
+ INST_PREFIX + conv["content"] + INST_POSTFIX + OUTPUT_PREFIX
87
+ )
88
+ else:
89
+ context.append(conv["content"] + OUTPUT_POSTFIX)
90
+ return "".join(context)
91
+
92
+ def get_answer_at_once(self):
93
+ context = self._get_llama_style_input()
94
+ response = self.model(
95
+ context,
96
+ max_tokens=self.max_generation_token,
97
+ stop=[],
98
+ echo=False,
99
+ stream=False,
100
+ )
101
+ return response, len(response)
102
+
103
+ def get_answer_stream_iter(self):
104
+ context = self._get_llama_style_input()
105
+ iter = self.model(
106
+ context,
107
+ max_tokens=self.max_generation_token,
108
+ stop=[],
109
+ echo=False,
110
+ stream=True,
111
+ )
112
+ partial_text = ""
113
+ for i in iter:
114
+ response = i["choices"][0]["text"]
115
+ partial_text += response
116
+ yield partial_text
modules/models/OpenAI.py ADDED
@@ -0,0 +1,270 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+
6
+ import colorama
7
+ import requests
8
+
9
+ from .. import shared
10
+ from ..config import retrieve_proxy, sensitive_id, usage_limit
11
+ from ..index_func import *
12
+ from ..presets import *
13
+ from ..utils import *
14
+ from .base_model import BaseLLMModel
15
+
16
+
17
+ class OpenAIClient(BaseLLMModel):
18
+ def __init__(
19
+ self,
20
+ model_name,
21
+ api_key,
22
+ system_prompt=INITIAL_SYSTEM_PROMPT,
23
+ temperature=1.0,
24
+ top_p=1.0,
25
+ user_name=""
26
+ ) -> None:
27
+ super().__init__(
28
+ model_name=model_name,
29
+ temperature=temperature,
30
+ top_p=top_p,
31
+ system_prompt=system_prompt,
32
+ user=user_name
33
+ )
34
+ self.api_key = api_key
35
+ self.need_api_key = True
36
+ self._refresh_header()
37
+
38
+ def get_answer_stream_iter(self):
39
+ response = self._get_response(stream=True)
40
+ if response is not None:
41
+ iter = self._decode_chat_response(response)
42
+ partial_text = ""
43
+ for i in iter:
44
+ partial_text += i
45
+ yield partial_text
46
+ else:
47
+ yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
48
+
49
+ def get_answer_at_once(self):
50
+ response = self._get_response()
51
+ response = json.loads(response.text)
52
+ content = response["choices"][0]["message"]["content"]
53
+ total_token_count = response["usage"]["total_tokens"]
54
+ return content, total_token_count
55
+
56
+ def count_token(self, user_input):
57
+ input_token_count = count_token(construct_user(user_input))
58
+ if self.system_prompt is not None and len(self.all_token_counts) == 0:
59
+ system_prompt_token_count = count_token(
60
+ construct_system(self.system_prompt)
61
+ )
62
+ return input_token_count + system_prompt_token_count
63
+ return input_token_count
64
+
65
+ def billing_info(self):
66
+ try:
67
+ curr_time = datetime.datetime.now()
68
+ last_day_of_month = get_last_day_of_month(
69
+ curr_time).strftime("%Y-%m-%d")
70
+ first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
71
+ usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
72
+ try:
73
+ usage_data = self._get_billing_data(usage_url)
74
+ except Exception as e:
75
+ # logging.error(f"获取API使用情况失败: " + str(e))
76
+ if "Invalid authorization header" in str(e):
77
+ return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
78
+ elif "Incorrect API key provided: sess" in str(e):
79
+ return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
80
+ return i18n("**获取API使用情况失败**")
81
+ # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
82
+ rounded_usage = round(usage_data["total_usage"] / 100, 5)
83
+ usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
84
+ from ..webui import get_html
85
+
86
+ # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
87
+ return get_html("billing_info.html").format(
88
+ label=i18n("本月使用金额"),
89
+ usage_percent=usage_percent,
90
+ rounded_usage=rounded_usage,
91
+ usage_limit=usage_limit
92
+ )
93
+ except requests.exceptions.ConnectTimeout:
94
+ status_text = (
95
+ STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
96
+ )
97
+ return status_text
98
+ except requests.exceptions.ReadTimeout:
99
+ status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
100
+ return status_text
101
+ except Exception as e:
102
+ import traceback
103
+ traceback.print_exc()
104
+ logging.error(i18n("获取API使用情况失败:") + str(e))
105
+ return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
106
+
107
+ def set_token_upper_limit(self, new_upper_limit):
108
+ pass
109
+
110
+ @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
111
+ def _get_response(self, stream=False):
112
+ openai_api_key = self.api_key
113
+ system_prompt = self.system_prompt
114
+ history = self.history
115
+ logging.debug(colorama.Fore.YELLOW +
116
+ f"{history}" + colorama.Fore.RESET)
117
+ headers = {
118
+ "Content-Type": "application/json",
119
+ "Authorization": f"Bearer {openai_api_key}",
120
+ }
121
+
122
+ if system_prompt is not None:
123
+ history = [construct_system(system_prompt), *history]
124
+
125
+ payload = {
126
+ "model": self.model_name,
127
+ "messages": history,
128
+ "temperature": self.temperature,
129
+ "top_p": self.top_p,
130
+ "n": self.n_choices,
131
+ "stream": stream,
132
+ "presence_penalty": self.presence_penalty,
133
+ "frequency_penalty": self.frequency_penalty,
134
+ }
135
+
136
+ if self.max_generation_token is not None:
137
+ payload["max_tokens"] = self.max_generation_token
138
+ if self.stop_sequence is not None:
139
+ payload["stop"] = self.stop_sequence
140
+ if self.logit_bias is not None:
141
+ payload["logit_bias"] = self.logit_bias
142
+ if self.user_identifier:
143
+ payload["user"] = self.user_identifier
144
+
145
+ if stream:
146
+ timeout = TIMEOUT_STREAMING
147
+ else:
148
+ timeout = TIMEOUT_ALL
149
+
150
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
151
+ if shared.state.completion_url != COMPLETION_URL:
152
+ logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
153
+
154
+ with retrieve_proxy():
155
+ try:
156
+ response = requests.post(
157
+ shared.state.completion_url,
158
+ headers=headers,
159
+ json=payload,
160
+ stream=stream,
161
+ timeout=timeout,
162
+ )
163
+ except:
164
+ return None
165
+ return response
166
+
167
+ def _refresh_header(self):
168
+ self.headers = {
169
+ "Content-Type": "application/json",
170
+ "Authorization": f"Bearer {sensitive_id}",
171
+ }
172
+
173
+ def _get_billing_data(self, billing_url):
174
+ with retrieve_proxy():
175
+ response = requests.get(
176
+ billing_url,
177
+ headers=self.headers,
178
+ timeout=TIMEOUT_ALL,
179
+ )
180
+
181
+ if response.status_code == 200:
182
+ data = response.json()
183
+ return data
184
+ else:
185
+ raise Exception(
186
+ f"API request failed with status code {response.status_code}: {response.text}"
187
+ )
188
+
189
+ def _decode_chat_response(self, response):
190
+ error_msg = ""
191
+ for chunk in response.iter_lines():
192
+ if chunk:
193
+ chunk = chunk.decode()
194
+ chunk_length = len(chunk)
195
+ try:
196
+ chunk = json.loads(chunk[6:])
197
+ except:
198
+ print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
199
+ error_msg += chunk
200
+ continue
201
+ if chunk_length > 6 and "delta" in chunk["choices"][0]:
202
+ if chunk["choices"][0]["finish_reason"] == "stop":
203
+ break
204
+ try:
205
+ yield chunk["choices"][0]["delta"]["content"]
206
+ except Exception as e:
207
+ # logging.error(f"Error: {e}")
208
+ continue
209
+ if error_msg:
210
+ raise Exception(error_msg)
211
+
212
+ def set_key(self, new_access_key):
213
+ ret = super().set_key(new_access_key)
214
+ self._refresh_header()
215
+ return ret
216
+
217
+ def _single_query_at_once(self, history, temperature=1.0):
218
+ timeout = TIMEOUT_ALL
219
+ headers = {
220
+ "Content-Type": "application/json",
221
+ "Authorization": f"Bearer {self.api_key}",
222
+ "temperature": f"{temperature}",
223
+ }
224
+ payload = {
225
+ "model": self.model_name,
226
+ "messages": history,
227
+ }
228
+ # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
229
+ if shared.state.completion_url != COMPLETION_URL:
230
+ logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
231
+
232
+ with retrieve_proxy():
233
+ response = requests.post(
234
+ shared.state.completion_url,
235
+ headers=headers,
236
+ json=payload,
237
+ stream=False,
238
+ timeout=timeout,
239
+ )
240
+
241
+ return response
242
+
243
+ def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
244
+ if len(self.history) == 2 and not single_turn_checkbox:
245
+ user_question = self.history[0]["content"]
246
+ if name_chat_method == i18n("模型自动总结(消耗tokens)"):
247
+ ai_answer = self.history[1]["content"]
248
+ try:
249
+ history = [
250
+ {"role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
251
+ {"role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
252
+ ]
253
+ response = self._single_query_at_once(
254
+ history, temperature=0.0)
255
+ response = json.loads(response.text)
256
+ content = response["choices"][0]["message"]["content"]
257
+ filename = replace_special_symbols(content) + ".json"
258
+ except Exception as e:
259
+ logging.info(f"自动命名失败。{e}")
260
+ filename = replace_special_symbols(user_question)[
261
+ :16] + ".json"
262
+ return self.rename_chat_history(filename, chatbot, user_name)
263
+ elif name_chat_method == i18n("第一条提问"):
264
+ filename = replace_special_symbols(user_question)[
265
+ :16] + ".json"
266
+ return self.rename_chat_history(filename, chatbot, user_name)
267
+ else:
268
+ return gr.update()
269
+ else:
270
+ return gr.update()
modules/models/XMChat.py ADDED
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import logging
6
+ import os
7
+ import uuid
8
+ from io import BytesIO
9
+
10
+ import requests
11
+ from PIL import Image
12
+
13
+ from ..index_func import *
14
+ from ..presets import *
15
+ from ..utils import *
16
+ from .base_model import BaseLLMModel
17
+
18
+
19
+ class XMChatClient(BaseLLMModel):
20
+ def __init__(self, api_key, user_name=""):
21
+ super().__init__(model_name="xmchat", user=user_name)
22
+ self.api_key = api_key
23
+ self.session_id = None
24
+ self.reset()
25
+ self.image_bytes = None
26
+ self.image_path = None
27
+ self.xm_history = []
28
+ self.url = "https://xmbot.net/web"
29
+ self.last_conv_id = None
30
+
31
+ def reset(self):
32
+ self.session_id = str(uuid.uuid4())
33
+ self.last_conv_id = None
34
+ return [], "已重置"
35
+
36
+ def image_to_base64(self, image_path):
37
+ # 打开并加载图片
38
+ img = Image.open(image_path)
39
+
40
+ # 获取图片的宽度和高度
41
+ width, height = img.size
42
+
43
+ # 计算压缩比例,以确保最长边小于4096像素
44
+ max_dimension = 2048
45
+ scale_ratio = min(max_dimension / width, max_dimension / height)
46
+
47
+ if scale_ratio < 1:
48
+ # 按压缩比例调整图片大小
49
+ new_width = int(width * scale_ratio)
50
+ new_height = int(height * scale_ratio)
51
+ img = img.resize((new_width, new_height), Image.ANTIALIAS)
52
+
53
+ # 将图片转换为jpg格式的二进制数据
54
+ buffer = BytesIO()
55
+ if img.mode == "RGBA":
56
+ img = img.convert("RGB")
57
+ img.save(buffer, format='JPEG')
58
+ binary_image = buffer.getvalue()
59
+
60
+ # 对二进制数据进行Base64编码
61
+ base64_image = base64.b64encode(binary_image).decode('utf-8')
62
+
63
+ return base64_image
64
+
65
+ def try_read_image(self, filepath):
66
+ def is_image_file(filepath):
67
+ # 判断文件是否为图片
68
+ valid_image_extensions = [
69
+ ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
70
+ file_extension = os.path.splitext(filepath)[1].lower()
71
+ return file_extension in valid_image_extensions
72
+
73
+ if is_image_file(filepath):
74
+ logging.info(f"读取图片文件: {filepath}")
75
+ self.image_bytes = self.image_to_base64(filepath)
76
+ self.image_path = filepath
77
+ else:
78
+ self.image_bytes = None
79
+ self.image_path = None
80
+
81
+ def like(self):
82
+ if self.last_conv_id is None:
83
+ return "点赞失败,你还没发送过消息"
84
+ data = {
85
+ "uuid": self.last_conv_id,
86
+ "appraise": "good"
87
+ }
88
+ requests.post(self.url, json=data)
89
+ return "👍点赞成功,感谢反馈~"
90
+
91
+ def dislike(self):
92
+ if self.last_conv_id is None:
93
+ return "点踩失败,你还没发送过消息"
94
+ data = {
95
+ "uuid": self.last_conv_id,
96
+ "appraise": "bad"
97
+ }
98
+ requests.post(self.url, json=data)
99
+ return "👎点踩成功,感谢反馈~"
100
+
101
+ def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
102
+ fake_inputs = real_inputs
103
+ display_append = ""
104
+ limited_context = False
105
+ return limited_context, fake_inputs, display_append, real_inputs, chatbot
106
+
107
+ def handle_file_upload(self, files, chatbot, language):
108
+ """if the model accepts multi modal input, implement this function"""
109
+ if files:
110
+ for file in files:
111
+ if file.name:
112
+ logging.info(f"尝试读取图像: {file.name}")
113
+ self.try_read_image(file.name)
114
+ if self.image_path is not None:
115
+ chatbot = chatbot + [((self.image_path,), None)]
116
+ if self.image_bytes is not None:
117
+ logging.info("使用图片作为输入")
118
+ # XMChat的一轮对话中实际上只能处理一张图片
119
+ self.reset()
120
+ conv_id = str(uuid.uuid4())
121
+ data = {
122
+ "user_id": self.api_key,
123
+ "session_id": self.session_id,
124
+ "uuid": conv_id,
125
+ "data_type": "imgbase64",
126
+ "data": self.image_bytes
127
+ }
128
+ response = requests.post(self.url, json=data)
129
+ response = json.loads(response.text)
130
+ logging.info(f"图片回复: {response['data']}")
131
+ return None, chatbot, None
132
+
133
+ def get_answer_at_once(self):
134
+ question = self.history[-1]["content"]
135
+ conv_id = str(uuid.uuid4())
136
+ self.last_conv_id = conv_id
137
+ data = {
138
+ "user_id": self.api_key,
139
+ "session_id": self.session_id,
140
+ "uuid": conv_id,
141
+ "data_type": "text",
142
+ "data": question
143
+ }
144
+ response = requests.post(self.url, json=data)
145
+ try:
146
+ response = json.loads(response.text)
147
+ return response["data"], len(response["data"])
148
+ except Exception as e:
149
+ return response.text, len(response.text)
modules/models/models.py CHANGED
@@ -1,597 +1,19 @@
1
  from __future__ import annotations
2
 
3
- import base64
4
- import json
5
  import logging
6
  import os
7
- import platform
8
- import traceback
9
- import uuid
10
- from io import BytesIO
11
 
12
  import colorama
13
  import commentjson as cjson
14
- import requests
15
- from PIL import Image
16
 
17
  from modules import config
18
 
19
- from .. import shared
20
- from ..config import retrieve_proxy, sensitive_id, usage_limit
21
  from ..index_func import *
22
  from ..presets import *
23
  from ..utils import *
24
  from .base_model import BaseLLMModel, ModelType
25
 
26
 
27
- class OpenAIClient(BaseLLMModel):
28
- def __init__(
29
- self,
30
- model_name,
31
- api_key,
32
- system_prompt=INITIAL_SYSTEM_PROMPT,
33
- temperature=1.0,
34
- top_p=1.0,
35
- user_name=""
36
- ) -> None:
37
- super().__init__(
38
- model_name=model_name,
39
- temperature=temperature,
40
- top_p=top_p,
41
- system_prompt=system_prompt,
42
- user=user_name
43
- )
44
- self.api_key = api_key
45
- self.need_api_key = True
46
- self._refresh_header()
47
-
48
- def get_answer_stream_iter(self):
49
- response = self._get_response(stream=True)
50
- if response is not None:
51
- iter = self._decode_chat_response(response)
52
- partial_text = ""
53
- for i in iter:
54
- partial_text += i
55
- yield partial_text
56
- else:
57
- yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
58
-
59
- def get_answer_at_once(self):
60
- response = self._get_response()
61
- response = json.loads(response.text)
62
- content = response["choices"][0]["message"]["content"]
63
- total_token_count = response["usage"]["total_tokens"]
64
- return content, total_token_count
65
-
66
- def count_token(self, user_input):
67
- input_token_count = count_token(construct_user(user_input))
68
- if self.system_prompt is not None and len(self.all_token_counts) == 0:
69
- system_prompt_token_count = count_token(
70
- construct_system(self.system_prompt)
71
- )
72
- return input_token_count + system_prompt_token_count
73
- return input_token_count
74
-
75
- def billing_info(self):
76
- try:
77
- curr_time = datetime.datetime.now()
78
- last_day_of_month = get_last_day_of_month(
79
- curr_time).strftime("%Y-%m-%d")
80
- first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
81
- usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
82
- try:
83
- usage_data = self._get_billing_data(usage_url)
84
- except Exception as e:
85
- # logging.error(f"获取API使用情况失败: " + str(e))
86
- if "Invalid authorization header" in str(e):
87
- return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
88
- elif "Incorrect API key provided: sess" in str(e):
89
- return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
90
- return i18n("**获取API使用情况失败**")
91
- # rounded_usage = "{:.5f}".format(usage_data["total_usage"] / 100)
92
- rounded_usage = round(usage_data["total_usage"] / 100, 5)
93
- usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
94
- from ..webui import get_html
95
-
96
- # return i18n("**本月使用金额** ") + f"\u3000 ${rounded_usage}"
97
- return get_html("billing_info.html").format(
98
- label = i18n("本月使用金额"),
99
- usage_percent = usage_percent,
100
- rounded_usage = rounded_usage,
101
- usage_limit = usage_limit
102
- )
103
- except requests.exceptions.ConnectTimeout:
104
- status_text = (
105
- STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
106
- )
107
- return status_text
108
- except requests.exceptions.ReadTimeout:
109
- status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
110
- return status_text
111
- except Exception as e:
112
- import traceback
113
- traceback.print_exc()
114
- logging.error(i18n("获取API使用情况失败:") + str(e))
115
- return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
116
-
117
- def set_token_upper_limit(self, new_upper_limit):
118
- pass
119
-
120
- @shared.state.switching_api_key # 在不开启多账号模式的时候,这个装饰器不会起作用
121
- def _get_response(self, stream=False):
122
- openai_api_key = self.api_key
123
- system_prompt = self.system_prompt
124
- history = self.history
125
- logging.debug(colorama.Fore.YELLOW +
126
- f"{history}" + colorama.Fore.RESET)
127
- headers = {
128
- "Content-Type": "application/json",
129
- "Authorization": f"Bearer {openai_api_key}",
130
- }
131
-
132
- if system_prompt is not None:
133
- history = [construct_system(system_prompt), *history]
134
-
135
- payload = {
136
- "model": self.model_name,
137
- "messages": history,
138
- "temperature": self.temperature,
139
- "top_p": self.top_p,
140
- "n": self.n_choices,
141
- "stream": stream,
142
- "presence_penalty": self.presence_penalty,
143
- "frequency_penalty": self.frequency_penalty,
144
- }
145
-
146
- if self.max_generation_token is not None:
147
- payload["max_tokens"] = self.max_generation_token
148
- if self.stop_sequence is not None:
149
- payload["stop"] = self.stop_sequence
150
- if self.logit_bias is not None:
151
- payload["logit_bias"] = self.logit_bias
152
- if self.user_identifier:
153
- payload["user"] = self.user_identifier
154
-
155
- if stream:
156
- timeout = TIMEOUT_STREAMING
157
- else:
158
- timeout = TIMEOUT_ALL
159
-
160
- # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
161
- if shared.state.completion_url != COMPLETION_URL:
162
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
163
-
164
- with retrieve_proxy():
165
- try:
166
- response = requests.post(
167
- shared.state.completion_url,
168
- headers=headers,
169
- json=payload,
170
- stream=stream,
171
- timeout=timeout,
172
- )
173
- except:
174
- traceback.print_exc()
175
- return None
176
- return response
177
-
178
- def _refresh_header(self):
179
- self.headers = {
180
- "Content-Type": "application/json",
181
- "Authorization": f"Bearer {sensitive_id}",
182
- }
183
-
184
-
185
- def _get_billing_data(self, billing_url):
186
- with retrieve_proxy():
187
- response = requests.get(
188
- billing_url,
189
- headers=self.headers,
190
- timeout=TIMEOUT_ALL,
191
- )
192
-
193
- if response.status_code == 200:
194
- data = response.json()
195
- return data
196
- else:
197
- raise Exception(
198
- f"API request failed with status code {response.status_code}: {response.text}"
199
- )
200
-
201
- def _decode_chat_response(self, response):
202
- error_msg = ""
203
- for chunk in response.iter_lines():
204
- if chunk:
205
- chunk = chunk.decode()
206
- chunk_length = len(chunk)
207
- try:
208
- chunk = json.loads(chunk[6:])
209
- except:
210
- print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
211
- error_msg += chunk
212
- continue
213
- if chunk_length > 6 and "delta" in chunk["choices"][0]:
214
- if chunk["choices"][0]["finish_reason"] == "stop":
215
- break
216
- try:
217
- yield chunk["choices"][0]["delta"]["content"]
218
- except Exception as e:
219
- # logging.error(f"Error: {e}")
220
- continue
221
- if error_msg:
222
- raise Exception(error_msg)
223
-
224
- def set_key(self, new_access_key):
225
- ret = super().set_key(new_access_key)
226
- self._refresh_header()
227
- return ret
228
-
229
- def _single_query_at_once(self, history, temperature=1.0):
230
- timeout = TIMEOUT_ALL
231
- headers = {
232
- "Content-Type": "application/json",
233
- "Authorization": f"Bearer {self.api_key}",
234
- "temperature": f"{temperature}",
235
- }
236
- payload = {
237
- "model": self.model_name,
238
- "messages": history,
239
- }
240
- # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
241
- if shared.state.completion_url != COMPLETION_URL:
242
- logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
243
-
244
- with retrieve_proxy():
245
- response = requests.post(
246
- shared.state.completion_url,
247
- headers=headers,
248
- json=payload,
249
- stream=False,
250
- timeout=timeout,
251
- )
252
-
253
- return response
254
-
255
-
256
- def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
257
- if len(self.history) == 2 and not single_turn_checkbox:
258
- user_question = self.history[0]["content"]
259
- if name_chat_method == i18n("模型自动总结(消耗tokens)"):
260
- ai_answer = self.history[1]["content"]
261
- try:
262
- history = [
263
- { "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
264
- { "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
265
- ]
266
- response = self._single_query_at_once(history, temperature=0.0)
267
- response = json.loads(response.text)
268
- content = response["choices"][0]["message"]["content"]
269
- filename = replace_special_symbols(content) + ".json"
270
- except Exception as e:
271
- logging.info(f"自动命名失败。{e}")
272
- filename = replace_special_symbols(user_question)[:16] + ".json"
273
- return self.rename_chat_history(filename, chatbot, user_name)
274
- elif name_chat_method == i18n("第一条提问"):
275
- filename = replace_special_symbols(user_question)[:16] + ".json"
276
- return self.rename_chat_history(filename, chatbot, user_name)
277
- else:
278
- return gr.update()
279
- else:
280
- return gr.update()
281
-
282
-
283
- class ChatGLM_Client(BaseLLMModel):
284
- def __init__(self, model_name, user_name="") -> None:
285
- super().__init__(model_name=model_name, user=user_name)
286
- import torch
287
- from transformers import AutoModel, AutoTokenizer
288
- global CHATGLM_TOKENIZER, CHATGLM_MODEL
289
- if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
290
- system_name = platform.system()
291
- model_path = None
292
- if os.path.exists("models"):
293
- model_dirs = os.listdir("models")
294
- if model_name in model_dirs:
295
- model_path = f"models/{model_name}"
296
- if model_path is not None:
297
- model_source = model_path
298
- else:
299
- model_source = f"THUDM/{model_name}"
300
- CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
301
- model_source, trust_remote_code=True
302
- )
303
- quantified = False
304
- if "int4" in model_name:
305
- quantified = True
306
- model = AutoModel.from_pretrained(
307
- model_source, trust_remote_code=True
308
- )
309
- if torch.cuda.is_available():
310
- # run on CUDA
311
- logging.info("CUDA is available, using CUDA")
312
- model = model.half().cuda()
313
- # mps加速还存在一些问题,暂时不使用
314
- elif system_name == "Darwin" and model_path is not None and not quantified:
315
- logging.info("Running on macOS, using MPS")
316
- # running on macOS and model already downloaded
317
- model = model.half().to("mps")
318
- else:
319
- logging.info("GPU is not available, using CPU")
320
- model = model.float()
321
- model = model.eval()
322
- CHATGLM_MODEL = model
323
-
324
- def _get_glm_style_input(self):
325
- history = [x["content"] for x in self.history]
326
- query = history.pop()
327
- logging.debug(colorama.Fore.YELLOW +
328
- f"{history}" + colorama.Fore.RESET)
329
- assert (
330
- len(history) % 2 == 0
331
- ), f"History should be even length. current history is: {history}"
332
- history = [[history[i], history[i + 1]]
333
- for i in range(0, len(history), 2)]
334
- return history, query
335
-
336
- def get_answer_at_once(self):
337
- history, query = self._get_glm_style_input()
338
- response, _ = CHATGLM_MODEL.chat(
339
- CHATGLM_TOKENIZER, query, history=history)
340
- return response, len(response)
341
-
342
- def get_answer_stream_iter(self):
343
- history, query = self._get_glm_style_input()
344
- for response, history in CHATGLM_MODEL.stream_chat(
345
- CHATGLM_TOKENIZER,
346
- query,
347
- history,
348
- max_length=self.token_upper_limit,
349
- top_p=self.top_p,
350
- temperature=self.temperature,
351
- ):
352
- yield response
353
-
354
-
355
- class LLaMA_Client(BaseLLMModel):
356
- def __init__(
357
- self,
358
- model_name,
359
- lora_path=None,
360
- user_name=""
361
- ) -> None:
362
- super().__init__(model_name=model_name, user=user_name)
363
- from lmflow.args import (DatasetArguments, InferencerArguments,
364
- ModelArguments)
365
- from lmflow.datasets.dataset import Dataset
366
- from lmflow.models.auto_model import AutoModel
367
- from lmflow.pipeline.auto_pipeline import AutoPipeline
368
-
369
- self.max_generation_token = 1000
370
- self.end_string = "\n\n"
371
- # We don't need input data
372
- data_args = DatasetArguments(dataset_path=None)
373
- self.dataset = Dataset(data_args)
374
- self.system_prompt = ""
375
-
376
- global LLAMA_MODEL, LLAMA_INFERENCER
377
- if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
378
- model_path = None
379
- if os.path.exists("models"):
380
- model_dirs = os.listdir("models")
381
- if model_name in model_dirs:
382
- model_path = f"models/{model_name}"
383
- if model_path is not None:
384
- model_source = model_path
385
- else:
386
- model_source = f"decapoda-research/{model_name}"
387
- # raise Exception(f"models目录下没有这个模型: {model_name}")
388
- if lora_path is not None:
389
- lora_path = f"lora/{lora_path}"
390
- model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
391
- use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
392
- pipeline_args = InferencerArguments(
393
- local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
394
-
395
- with open(pipeline_args.deepspeed, "r", encoding="utf-8") as f:
396
- ds_config = json.load(f)
397
- LLAMA_MODEL = AutoModel.get_model(
398
- model_args,
399
- tune_strategy="none",
400
- ds_config=ds_config,
401
- )
402
- LLAMA_INFERENCER = AutoPipeline.get_pipeline(
403
- pipeline_name="inferencer",
404
- model_args=model_args,
405
- data_args=data_args,
406
- pipeline_args=pipeline_args,
407
- )
408
-
409
- def _get_llama_style_input(self):
410
- history = []
411
- instruction = ""
412
- if self.system_prompt:
413
- instruction = (f"Instruction: {self.system_prompt}\n")
414
- for x in self.history:
415
- if x["role"] == "user":
416
- history.append(f"{instruction}Input: {x['content']}")
417
- else:
418
- history.append(f"Output: {x['content']}")
419
- context = "\n\n".join(history)
420
- context += "\n\nOutput: "
421
- return context
422
-
423
- def get_answer_at_once(self):
424
- context = self._get_llama_style_input()
425
-
426
- input_dataset = self.dataset.from_dict(
427
- {"type": "text_only", "instances": [{"text": context}]}
428
- )
429
-
430
- output_dataset = LLAMA_INFERENCER.inference(
431
- model=LLAMA_MODEL,
432
- dataset=input_dataset,
433
- max_new_tokens=self.max_generation_token,
434
- temperature=self.temperature,
435
- )
436
-
437
- response = output_dataset.to_dict()["instances"][0]["text"]
438
- return response, len(response)
439
-
440
- def get_answer_stream_iter(self):
441
- context = self._get_llama_style_input()
442
- partial_text = ""
443
- step = 1
444
- for _ in range(0, self.max_generation_token, step):
445
- input_dataset = self.dataset.from_dict(
446
- {"type": "text_only", "instances": [
447
- {"text": context + partial_text}]}
448
- )
449
- output_dataset = LLAMA_INFERENCER.inference(
450
- model=LLAMA_MODEL,
451
- dataset=input_dataset,
452
- max_new_tokens=step,
453
- temperature=self.temperature,
454
- )
455
- response = output_dataset.to_dict()["instances"][0]["text"]
456
- if response == "" or response == self.end_string:
457
- break
458
- partial_text += response
459
- yield partial_text
460
-
461
-
462
- class XMChat(BaseLLMModel):
463
- def __init__(self, api_key, user_name=""):
464
- super().__init__(model_name="xmchat", user=user_name)
465
- self.api_key = api_key
466
- self.session_id = None
467
- self.reset()
468
- self.image_bytes = None
469
- self.image_path = None
470
- self.xm_history = []
471
- self.url = "https://xmbot.net/web"
472
- self.last_conv_id = None
473
-
474
- def reset(self):
475
- self.session_id = str(uuid.uuid4())
476
- self.last_conv_id = None
477
- return super().reset()
478
-
479
- def image_to_base64(self, image_path):
480
- # 打开并加载图片
481
- img = Image.open(image_path)
482
-
483
- # 获取图片的宽度和高度
484
- width, height = img.size
485
-
486
- # 计算压缩比例,以确保最长边小于4096像素
487
- max_dimension = 2048
488
- scale_ratio = min(max_dimension / width, max_dimension / height)
489
-
490
- if scale_ratio < 1:
491
- # 按压缩比例调整图片大小
492
- new_width = int(width * scale_ratio)
493
- new_height = int(height * scale_ratio)
494
- img = img.resize((new_width, new_height), Image.ANTIALIAS)
495
-
496
- # 将图片转换为jpg格式的二进制数据
497
- buffer = BytesIO()
498
- if img.mode == "RGBA":
499
- img = img.convert("RGB")
500
- img.save(buffer, format='JPEG')
501
- binary_image = buffer.getvalue()
502
-
503
- # 对二进制数据进行Base64编码
504
- base64_image = base64.b64encode(binary_image).decode('utf-8')
505
-
506
- return base64_image
507
-
508
- def try_read_image(self, filepath):
509
- def is_image_file(filepath):
510
- # 判断文件是否为图片
511
- valid_image_extensions = [
512
- ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
513
- file_extension = os.path.splitext(filepath)[1].lower()
514
- return file_extension in valid_image_extensions
515
-
516
- if is_image_file(filepath):
517
- logging.info(f"读取图片文件: {filepath}")
518
- self.image_bytes = self.image_to_base64(filepath)
519
- self.image_path = filepath
520
- else:
521
- self.image_bytes = None
522
- self.image_path = None
523
-
524
- def like(self):
525
- if self.last_conv_id is None:
526
- return "点赞失败,你还没发送过消息"
527
- data = {
528
- "uuid": self.last_conv_id,
529
- "appraise": "good"
530
- }
531
- requests.post(self.url, json=data)
532
- return "👍点赞成功,感谢反馈~"
533
-
534
- def dislike(self):
535
- if self.last_conv_id is None:
536
- return "点踩失败,你还没发送过消息"
537
- data = {
538
- "uuid": self.last_conv_id,
539
- "appraise": "bad"
540
- }
541
- requests.post(self.url, json=data)
542
- return "👎点踩成功,感谢反馈~"
543
-
544
- def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
545
- fake_inputs = real_inputs
546
- display_append = ""
547
- limited_context = False
548
- return limited_context, fake_inputs, display_append, real_inputs, chatbot
549
-
550
- def handle_file_upload(self, files, chatbot, language):
551
- """if the model accepts multi modal input, implement this function"""
552
- if files:
553
- for file in files:
554
- if file.name:
555
- logging.info(f"尝试读取图像: {file.name}")
556
- self.try_read_image(file.name)
557
- if self.image_path is not None:
558
- chatbot = chatbot + [((self.image_path,), None)]
559
- if self.image_bytes is not None:
560
- logging.info("使用图片作为输入")
561
- # XMChat的一轮对话中实际上只能处理一张图片
562
- self.reset()
563
- conv_id = str(uuid.uuid4())
564
- data = {
565
- "user_id": self.api_key,
566
- "session_id": self.session_id,
567
- "uuid": conv_id,
568
- "data_type": "imgbase64",
569
- "data": self.image_bytes
570
- }
571
- response = requests.post(self.url, json=data)
572
- response = json.loads(response.text)
573
- logging.info(f"图片回复: {response['data']}")
574
- return None, chatbot, None
575
-
576
- def get_answer_at_once(self):
577
- question = self.history[-1]["content"]
578
- conv_id = str(uuid.uuid4())
579
- self.last_conv_id = conv_id
580
- data = {
581
- "user_id": self.api_key,
582
- "session_id": self.session_id,
583
- "uuid": conv_id,
584
- "data_type": "text",
585
- "data": question
586
- }
587
- response = requests.post(self.url, json=data)
588
- try:
589
- response = json.loads(response.text)
590
- return response["data"], len(response["data"])
591
- except Exception as e:
592
- return response.text, len(response.text)
593
-
594
-
595
  def get_model(
596
  model_name,
597
  lora_model_path=None,
@@ -605,7 +27,7 @@ def get_model(
605
  msg = i18n("模型设置为了:") + f" {model_name}"
606
  model_type = ModelType.get_type(model_name)
607
  lora_selector_visibility = False
608
- lora_choices = []
609
  dont_change_lora_selector = False
610
  if model_type != ModelType.OpenAI:
611
  config.local_embedding = True
@@ -615,6 +37,7 @@ def get_model(
615
  try:
616
  if model_type == ModelType.OpenAI:
617
  logging.info(f"正在加载OpenAI模型: {model_name}")
 
618
  access_key = os.environ.get("OPENAI_API_KEY", access_key)
619
  model = OpenAIClient(
620
  model_name=model_name,
@@ -626,16 +49,17 @@ def get_model(
626
  )
627
  elif model_type == ModelType.ChatGLM:
628
  logging.info(f"正在加载ChatGLM模型: {model_name}")
 
629
  model = ChatGLM_Client(model_name, user_name=user_name)
630
  elif model_type == ModelType.LLaMA and lora_model_path == "":
631
  msg = f"现在请为 {model_name} 选择LoRA模型"
632
  logging.info(msg)
633
  lora_selector_visibility = True
634
  if os.path.isdir("lora"):
635
- get_file_names_by_pinyin("lora", filetypes=[""])
636
- lora_choices = ["No LoRA"] + lora_choices
637
  elif model_type == ModelType.LLaMA and lora_model_path != "":
638
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
 
639
  dont_change_lora_selector = True
640
  if lora_model_path == "No LoRA":
641
  lora_model_path = None
@@ -645,9 +69,10 @@ def get_model(
645
  model = LLaMA_Client(
646
  model_name, lora_model_path, user_name=user_name)
647
  elif model_type == ModelType.XMChat:
 
648
  if os.environ.get("XMCHAT_API_KEY") != "":
649
  access_key = os.environ.get("XMCHAT_API_KEY")
650
- model = XMChat(api_key=access_key, user_name=user_name)
651
  elif model_type == ModelType.StableLM:
652
  from .StableLM import StableLM_Client
653
  model = StableLM_Client(model_name, user_name=user_name)
@@ -656,30 +81,35 @@ def get_model(
656
  model = MOSS_Client(model_name, user_name=user_name)
657
  elif model_type == ModelType.YuanAI:
658
  from .inspurai import Yuan_Client
659
- model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
 
660
  elif model_type == ModelType.Minimax:
661
  from .minimax import MiniMax_Client
662
  if os.environ.get("MINIMAX_API_KEY") != "":
663
  access_key = os.environ.get("MINIMAX_API_KEY")
664
- model = MiniMax_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
 
665
  elif model_type == ModelType.ChuanhuAgent:
666
  from .ChuanhuAgent import ChuanhuAgent_Client
667
  model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
668
  msg = i18n("启用的工具:") + ", ".join([i.name for i in model.tools])
669
  elif model_type == ModelType.GooglePaLM:
670
- from .Google_PaLM import Google_PaLM_Client
671
  access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
672
- model = Google_PaLM_Client(model_name, access_key, user_name=user_name)
 
673
  elif model_type == ModelType.LangchainChat:
674
  from .Azure import Azure_OpenAI_Client
675
  model = Azure_OpenAI_Client(model_name, user_name=user_name)
676
  elif model_type == ModelType.Midjourney:
677
  from .midjourney import Midjourney_Client
678
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
679
- model = Midjourney_Client(model_name, mj_proxy_api_secret, user_name=user_name)
 
680
  elif model_type == ModelType.Spark:
681
  from .spark import Spark_Client
682
- model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv("SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
 
683
  elif model_type == ModelType.Unknown:
684
  raise ValueError(f"未知模型: {model_name}")
685
  logging.info(msg)
 
1
  from __future__ import annotations
2
 
 
 
3
  import logging
4
  import os
 
 
 
 
5
 
6
  import colorama
7
  import commentjson as cjson
 
 
8
 
9
  from modules import config
10
 
 
 
11
  from ..index_func import *
12
  from ..presets import *
13
  from ..utils import *
14
  from .base_model import BaseLLMModel, ModelType
15
 
16
 
17
  def get_model(
18
  model_name,
19
  lora_model_path=None,
 
27
  msg = i18n("模型设置为了:") + f" {model_name}"
28
  model_type = ModelType.get_type(model_name)
29
  lora_selector_visibility = False
30
+ lora_choices = ["No LoRA"]
31
  dont_change_lora_selector = False
32
  if model_type != ModelType.OpenAI:
33
  config.local_embedding = True
 
37
  try:
38
  if model_type == ModelType.OpenAI:
39
  logging.info(f"正在加载OpenAI模型: {model_name}")
40
+ from .OpenAI import OpenAIClient
41
  access_key = os.environ.get("OPENAI_API_KEY", access_key)
42
  model = OpenAIClient(
43
  model_name=model_name,
 
49
  )
50
  elif model_type == ModelType.ChatGLM:
51
  logging.info(f"正在加载ChatGLM模型: {model_name}")
52
+ from .ChatGLM import ChatGLM_Client
53
  model = ChatGLM_Client(model_name, user_name=user_name)
54
  elif model_type == ModelType.LLaMA and lora_model_path == "":
55
  msg = f"现在请为 {model_name} 选择LoRA模型"
56
  logging.info(msg)
57
  lora_selector_visibility = True
58
  if os.path.isdir("lora"):
59
+ lora_choices = ["No LoRA"] + get_file_names_by_pinyin("lora", filetypes=[""])
 
60
  elif model_type == ModelType.LLaMA and lora_model_path != "":
61
  logging.info(f"正在加载LLaMA模型: {model_name} + {lora_model_path}")
62
+ from .LLaMA import LLaMA_Client
63
  dont_change_lora_selector = True
64
  if lora_model_path == "No LoRA":
65
  lora_model_path = None
 
69
  model = LLaMA_Client(
70
  model_name, lora_model_path, user_name=user_name)
71
  elif model_type == ModelType.XMChat:
72
+ from .XMChat import XMChatClient
73
  if os.environ.get("XMCHAT_API_KEY") != "":
74
  access_key = os.environ.get("XMCHAT_API_KEY")
75
+ model = XMChatClient(api_key=access_key, user_name=user_name)
76
  elif model_type == ModelType.StableLM:
77
  from .StableLM import StableLM_Client
78
  model = StableLM_Client(model_name, user_name=user_name)
 
81
  model = MOSS_Client(model_name, user_name=user_name)
82
  elif model_type == ModelType.YuanAI:
83
  from .inspurai import Yuan_Client
84
+ model = Yuan_Client(model_name, api_key=access_key,
85
+ user_name=user_name, system_prompt=system_prompt)
86
  elif model_type == ModelType.Minimax:
87
  from .minimax import MiniMax_Client
88
  if os.environ.get("MINIMAX_API_KEY") != "":
89
  access_key = os.environ.get("MINIMAX_API_KEY")
90
+ model = MiniMax_Client(
91
+ model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
92
  elif model_type == ModelType.ChuanhuAgent:
93
  from .ChuanhuAgent import ChuanhuAgent_Client
94
  model = ChuanhuAgent_Client(model_name, access_key, user_name=user_name)
95
  msg = i18n("启用的工具:") + ", ".join([i.name for i in model.tools])
96
  elif model_type == ModelType.GooglePaLM:
97
+ from .GooglePaLM import Google_PaLM_Client
98
  access_key = os.environ.get("GOOGLE_PALM_API_KEY", access_key)
99
+ model = Google_PaLM_Client(
100
+ model_name, access_key, user_name=user_name)
101
  elif model_type == ModelType.LangchainChat:
102
  from .Azure import Azure_OpenAI_Client
103
  model = Azure_OpenAI_Client(model_name, user_name=user_name)
104
  elif model_type == ModelType.Midjourney:
105
  from .midjourney import Midjourney_Client
106
  mj_proxy_api_secret = os.getenv("MIDJOURNEY_PROXY_API_SECRET")
107
+ model = Midjourney_Client(
108
+ model_name, mj_proxy_api_secret, user_name=user_name)
109
  elif model_type == ModelType.Spark:
110
  from .spark import Spark_Client
111
+ model = Spark_Client(model_name, os.getenv("SPARK_APPID"), os.getenv(
112
+ "SPARK_API_KEY"), os.getenv("SPARK_API_SECRET"), user_name=user_name)
113
  elif model_type == ModelType.Unknown:
114
  raise ValueError(f"未知模型: {model_name}")
115
  logging.info(msg)
modules/presets.py CHANGED
@@ -83,12 +83,21 @@ LOCAL_MODELS = [
     "chatglm2-6b-int4",
     "StableLM",
     "MOSS",
-    "llama-7b-hf",
-    "llama-13b-hf",
-    "llama-30b-hf",
-    "llama-65b-hf",
+    "Llama-2-7B-Chat",
 ]
 
+# Additional metadata for local models
+MODEL_METADATA = {
+    "Llama-2-7B":{
+        "repo_id": "TheBloke/Llama-2-7B-GGUF",
+        "filelist": ["llama-2-7b.Q6_K.gguf"],
+    },
+    "Llama-2-7B-Chat":{
+        "repo_id": "TheBloke/Llama-2-7b-Chat-GGUF",
+        "filelist": ["llama-2-7b-chat.Q6_K.gguf"],
+    }
+}
+
 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
     MODELS = ONLINE_MODELS
 else:
@@ -135,8 +144,8 @@ REPLY_LANGUAGES = [
 ]
 
 HISTORY_NAME_METHODS = [
-    i18n("根据日期时间"),
-    i18n("第一条提问"),
+    i18n("根据日期时间"),
+    i18n("第一条提问"),
     i18n("模型自动总结(消耗tokens)"),
 ]
 
@@ -266,4 +275,3 @@ small_and_beautiful_theme = gr.themes.Soft(
     # gradio 会把这个几个chatbot打头的变量应用到其他md渲染的地方,鬼晓得怎么想的。。。
     chatbot_code_background_color_dark="*neutral_950",
 )
-
requirements_advanced.txt CHANGED
@@ -1,11 +1,8 @@
 transformers
 huggingface_hub
 torch
-icetk
-protobuf==3.19.0
-git+https://github.com/OptimalScale/LMFlow.git
 cpm-kernels
 sentence_transformers
 accelerate
 sentencepiece
-datasets
+llama-cpp-python
run_Linux.sh CHANGED
File without changes
run_macOS.command CHANGED
File without changes