Tuchuanhuhuhu committed
Commit 8c04739 · Parent: 4b9ef74

feat: Azure OpenAI API supports embedding
Files changed:
- config_example.json (+5 -2)
- modules/config.py (+39 -25)
- modules/index_func.py (+12 -4)
- modules/models/azure.py (+1 -1)
config_example.json CHANGED
@@ -9,10 +9,13 @@
     "minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型

     //== Azure ==
+    "openai_api_type": "openai", // 可选项:azure, openai
     "azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
-    "
+    "azure_openai_api_base_url": "", // 你的 Azure Base URL
     "azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
-    "azure_deployment_name": "", // 你的 Azure
+    "azure_deployment_name": "", // 你的 Azure OpenAI Chat 模型 Deployment 名称
+    "azure_embedding_deployment_name": "", // 你的 Azure OpenAI Embedding 模型 Deployment 名称
+    "azure_embedding_model_name": "text-embedding-ada-002", // 你的 Azure OpenAI Embedding 模型名称

     //== 基础配置 ==
     "language": "auto", // 界面语言,可选"auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
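To make the new keys concrete, here is a minimal sketch (not part of the commit) of what a filled-in Azure block might look like. The endpoint, key, and deployment names are placeholders, and `openai_api_type` is set to "azure" to switch the app from the regular OpenAI API to Azure.

```python
# Hypothetical filled-in Azure section of config.json; all values are placeholders.
import json

azure_example = {
    "openai_api_type": "azure",                                   # "openai" (default) or "azure"
    "azure_openai_api_key": "<your-azure-key>",
    "azure_openai_api_base_url": "https://<your-resource>.openai.azure.com",
    "azure_openai_api_version": "2023-05-15",
    "azure_deployment_name": "<chat-deployment>",                 # deployment serving the chat model
    "azure_embedding_deployment_name": "<embedding-deployment>",  # deployment serving the embedding model
    "azure_embedding_model_name": "text-embedding-ada-002",
}

# Pretty-print the block as it would appear inside config.json
print(json.dumps(azure_example, indent=4, ensure_ascii=False))
```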
modules/config.py CHANGED
@@ -39,19 +39,22 @@ if os.path.exists("config.json"):
 else:
     config = {}

+
 def load_config_to_environ(key_list):
     global config
     for key in key_list:
         if key in config:
             os.environ[key.upper()] = os.environ.get(key.upper(), config[key])

+
 sensitive_id = config.get("sensitive_id", "")
 sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)

 lang_config = config.get("language", "auto")
 language = os.environ.get("LANGUAGE", lang_config)

-hide_history_when_not_logged_in = config.get(
+hide_history_when_not_logged_in = config.get(
+    "hide_history_when_not_logged_in", False)
 check_update = config.get("check_update", True)
 show_api_billing = config.get("show_api_billing", False)
 show_api_billing = bool(os.environ.get("SHOW_API_BILLING", show_api_billing))
@@ -68,31 +71,32 @@ if os.path.exists("auth.json"):
     logging.info("检测到auth.json文件,正在进行迁移...")
     auth_list = []
     with open("auth.json", "r", encoding='utf-8') as f:
-
-
-
-
-
-
-
+        auth = json.load(f)
+        for _ in auth:
+            if auth[_]["username"] and auth[_]["password"]:
+                auth_list.append((auth[_]["username"], auth[_]["password"]))
+            else:
+                logging.error("请检查auth.json文件中的用户名和密码!")
+                sys.exit(1)
     config["users"] = auth_list
     os.rename("auth.json", "auth(deprecated).json")
     with open("config.json", "w", encoding='utf-8') as f:
         json.dump(config, f, indent=4, ensure_ascii=False)

-
+# 处理docker if we are running in Docker
 dockerflag = config.get("dockerflag", False)
 if os.environ.get("dockerrun") == "yes":
     dockerflag = True

-
+# 处理 api-key 以及 允许的用户列表
 my_api_key = config.get("openai_api_key", "")
 my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
 os.environ["OPENAI_API_KEY"] = my_api_key
 os.environ["OPENAI_EMBEDDING_API_KEY"] = my_api_key

 google_palm_api_key = config.get("google_palm_api_key", "")
-google_palm_api_key = os.environ.get(
+google_palm_api_key = os.environ.get(
+    "GOOGLE_PALM_API_KEY", google_palm_api_key)
 os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key

 xmchat_api_key = config.get("xmchat_api_key", "")
@@ -103,13 +107,14 @@ os.environ["MINIMAX_API_KEY"] = minimax_api_key
 minimax_group_id = config.get("minimax_group_id", "")
 os.environ["MINIMAX_GROUP_ID"] = minimax_group_id

-load_config_to_environ(["
+load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
+                        "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])


 usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))

-
-multi_api_key = config.get("multi_api_key", False)
+# 多账户机制
+multi_api_key = config.get("multi_api_key", False)  # 是否开启多账户机制
 if multi_api_key:
     api_key_list = config.get("api_key_list", [])
     if len(api_key_list) == 0:
@@ -117,23 +122,26 @@ if multi_api_key:
         sys.exit(1)
     shared.state.set_api_key_queue(api_key_list)

-auth_list = config.get("users", [])
+auth_list = config.get("users", [])  # 实际上是使用者的列表
 authflag = len(auth_list) > 0  # 是否开启认证的状态值,改为判断auth_list长度

 # 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
-api_host = os.environ.get(
+api_host = os.environ.get(
+    "OPENAI_API_BASE", config.get("openai_api_base", None))
 if api_host is not None:
     shared.state.set_api_host(api_host)
     os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
     logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")

-default_chuanhu_assistant_model = config.get(
+default_chuanhu_assistant_model = config.get(
+    "default_chuanhu_assistant_model", "gpt-3.5-turbo")
 for x in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID", "SERPAPI_API_KEY"]:
     if config.get(x, None) is not None:
         os.environ[x] = config[x]

+
 @contextmanager
-def retrieve_openai_api(api_key
+def retrieve_openai_api(api_key=None):
     old_api_key = os.environ.get("OPENAI_API_KEY", "")
     if api_key is None:
         os.environ["OPENAI_API_KEY"] = my_api_key
@@ -143,14 +151,15 @@ def retrieve_openai_api(api_key = None):
     yield api_key
     os.environ["OPENAI_API_KEY"] = old_api_key

-
+
+# 处理log
 log_level = config.get("log_level", "INFO")
 logging.basicConfig(
     level=log_level,
     format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
 )

-
+# 处理代理:
 http_proxy = os.environ.get("HTTP_PROXY", "")
 https_proxy = os.environ.get("HTTPS_PROXY", "")
 http_proxy = config.get("http_proxy", http_proxy)
@@ -160,7 +169,8 @@ https_proxy = config.get("https_proxy", https_proxy)
     os.environ["HTTP_PROXY"] = ""
     os.environ["HTTPS_PROXY"] = ""

-local_embedding = config.get("local_embedding", False)
+local_embedding = config.get("local_embedding", False)  # 是否使用本地embedding
+

 @contextmanager
 def retrieve_proxy(proxy=None):
@@ -177,12 +187,13 @@ def retrieve_proxy(proxy=None):
     old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
     os.environ["HTTP_PROXY"] = http_proxy
     os.environ["HTTPS_PROXY"] = https_proxy
-    yield http_proxy, https_proxy
+    yield http_proxy, https_proxy  # return new proxy

     # return old proxy
     os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var

-
+
+# 处理latex options
 user_latex_option = config.get("latex_option", "default")
 if user_latex_option == "default":
     latex_delimiters_set = [
@@ -219,16 +230,19 @@ else:
         {"left": "\\[", "right": "\\]", "display": True},
     ]

-
+# 处理advance docs
 advance_docs = defaultdict(lambda: defaultdict(dict))
 advance_docs.update(config.get("advance_docs", {}))
+
+
 def update_doc_config(two_column_pdf):
     global advance_docs
     advance_docs["pdf"]["two_column"] = two_column_pdf

     logging.info(f"更新后的文件参数为:{advance_docs}")

-
+
+# 处理gradio.launch参数
 server_name = config.get("server_name", None)
 server_port = config.get("server_port", None)
 if server_name is None:
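The glue for the new Azure settings is the existing load_config_to_environ() helper: each listed key is uppercased and exported as an environment variable unless it is already set in the environment, which is how modules/models/azure.py and modules/index_func.py later find the values. A standalone sketch of that flow (placeholder values, trimmed key list):

```python
# Standalone sketch of how config.json keys become the AZURE_* environment
# variables used elsewhere in this commit. Values are placeholders.
import os

config = {
    "openai_api_type": "azure",
    "azure_openai_api_base_url": "https://<your-resource>.openai.azure.com",
    "azure_embedding_deployment_name": "<embedding-deployment>",
}

def load_config_to_environ(key_list):
    for key in key_list:
        if key in config:
            # An already-set environment variable always wins over config.json.
            os.environ[key.upper()] = os.environ.get(key.upper(), config[key])

load_config_to_environ(["openai_api_type", "azure_openai_api_base_url",
                        "azure_embedding_deployment_name"])

print(os.environ["AZURE_OPENAI_API_BASE_URL"])        # read by modules/models/azure.py
print(os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"])  # read by modules/index_func.py
```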
modules/index_func.py CHANGED
@@ -51,7 +51,8 @@ def get_documents(file_src):
                 pdfReader = PyPDF2.PdfReader(pdfFileObj)
                 for page in tqdm(pdfReader.pages):
                     pdftext += page.extract_text()
-            texts = [Document(page_content=pdftext,
+            texts = [Document(page_content=pdftext,
+                              metadata={"source": filepath})]
         elif file_type == ".docx":
             logging.debug("Loading Word...")
             from langchain.document_loaders import UnstructuredWordDocumentLoader
@@ -72,7 +73,8 @@ def get_documents(file_src):
             text_list = excel_to_string(filepath)
             texts = []
             for elem in text_list:
-                texts.append(Document(page_content=elem,
+                texts.append(Document(page_content=elem,
+                                      metadata={"source": filepath}))
         else:
             logging.debug("Loading text file...")
             from langchain.document_loaders import TextLoader
@@ -115,10 +117,16 @@ def construct_index(
     index_path = f"./index/{index_name}"
     if local_embedding:
         from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-        embeddings = HuggingFaceEmbeddings(
+        embeddings = HuggingFaceEmbeddings(
+            model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
     else:
         from langchain.embeddings import OpenAIEmbeddings
-
+        if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
+            embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
+                "OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
+        else:
+            embeddings = OpenAIEmbeddings(deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"], openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
+                                          model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
     if os.path.exists(index_path):
         logging.info("找到了缓存的索引文件,加载中……")
         return FAISS.load_local(index_path, embeddings)
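The core of the change in construct_index() is a three-way choice of embedding backend: local sentence-transformers, plain OpenAI, or Azure OpenAI selected via OPENAI_API_TYPE. Below is a sketch of that logic pulled out as a standalone helper for clarity; get_embeddings is a hypothetical name, not part of the commit, while the langchain constructors and environment variable names are taken from the diff above.

```python
# Sketch (assuming langchain 0.0.x is installed) of the embedding selection
# logic introduced in construct_index(). Not part of the commit.
import os

def get_embeddings(local_embedding: bool, api_key: str):
    if local_embedding:
        # Local embeddings: no API key or network access needed.
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
    from langchain.embeddings import OpenAIEmbeddings
    if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
        # Regular OpenAI endpoint, optionally with a custom API base.
        return OpenAIEmbeddings(
            openai_api_base=os.environ.get("OPENAI_API_BASE", None),
            openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
    # Azure: the deployment serves the embedding model named in config.json.
    return OpenAIEmbeddings(
        deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"],
        model=os.environ["AZURE_EMBEDDING_MODEL_NAME"],
        openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
        openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
        openai_api_type="azure")
```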
modules/models/azure.py CHANGED
@@ -9,7 +9,7 @@ class Azure_OpenAI_Client(Base_Chat_Langchain_Client):
     def setup_model(self):
         # inplement this to setup the model then return it
         return AzureChatOpenAI(
-            openai_api_base=os.environ["
+            openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
             openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
             deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
             openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
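For completeness, a usage sketch (not from the commit) of how the chat client ends up wired to the renamed AZURE_OPENAI_API_BASE_URL variable. The import path is an assumption about the langchain version in use, and the values are placeholders that would normally come from config.json via load_config_to_environ rather than being set by hand.

```python
# Hypothetical usage sketch; placeholder values, langchain 0.0.x assumed.
import os
from langchain.chat_models import AzureChatOpenAI

# Normally exported by modules/config.py from config.json.
os.environ.setdefault("AZURE_OPENAI_API_BASE_URL", "https://<your-resource>.openai.azure.com")
os.environ.setdefault("AZURE_OPENAI_API_VERSION", "2023-05-15")
os.environ.setdefault("AZURE_DEPLOYMENT_NAME", "<chat-deployment>")
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<your-azure-key>")

# Mirrors Azure_OpenAI_Client.setup_model() above.
model = AzureChatOpenAI(
    openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
)
```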