MZhaovo committed on
Commit 5e4ca56
1 Parent(s): 1bda668

A more elegant way to change the custom API

ChuanhuChatbot.py CHANGED
@@ -164,12 +164,12 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
                     )
 
                 with gr.Accordion("网络设置", open=False):
-                    apiurlTxt = gr.Textbox(
+                    apihostTxt = gr.Textbox(
                         show_label=True,
-                        placeholder=f"在这里输入API地址...",
-                        label="API地址",
-                        value="https://api.openai.com/v1/chat/completions",
-                        lines=2,
+                        placeholder=f"在这里输入API-Host...",
+                        label="API-Host",
+                        value="api.openai.com",
+                        lines=1,
                     )
                     changeAPIURLBtn = gr.Button("🔄 切换API地址")
                     proxyTxt = gr.Textbox(
@@ -343,11 +343,11 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
 
     # Advanced
     default_btn.click(
-        reset_default, [], [apiurlTxt, proxyTxt, status_display], show_progress=True
+        reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True
     )
     changeAPIURLBtn.click(
-        change_api_url,
-        [apiurlTxt],
+        change_api_host,
+        [apihostTxt],
        [status_display],
        show_progress=True,
    )
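For readers unfamiliar with Gradio, the wiring above follows the usual Textbox/Button pattern: the button's click handler receives the textbox value and its return value is written into the status component. A minimal, self-contained sketch of that pattern, assuming a stand-in handler and a Markdown status component (not the project's actual change_api_host or status_display):

import gradio as gr

def change_api_host(host):
    # Hypothetical stand-in; the commit's real handler lives in modules/utils.py.
    return f"API-Host changed to {host}"

with gr.Blocks() as demo:
    with gr.Accordion("Network settings", open=False):
        apihostTxt = gr.Textbox(show_label=True, label="API-Host",
                                value="api.openai.com", lines=1)
        changeAPIURLBtn = gr.Button("🔄 Switch API host")
    status_display = gr.Markdown()
    # The textbox value is the input; the handler's return value lands in status_display.
    changeAPIURLBtn.click(change_api_host, [apihostTxt], [status_display], show_progress=True)

demo.launch()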
modules/chat_func.py CHANGED
@@ -13,9 +13,7 @@ import colorama
 from duckduckgo_search import ddg
 import asyncio
 import aiohttp
-from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
-from llama_index.indices.query.schema import QueryBundle
-from langchain.llms import OpenAIChat
+
 
 from modules.presets import *
 from modules.llama_func import *
@@ -63,13 +61,13 @@ def get_response(
         timeout = timeout_all
 
 
-    # 如果有自定义的api-url,使用自定义url发送请求,否则使用默认设置发送请求
-    if shared.state.api_url != API_URL:
-        logging.info(f"使用自定义API URL: {shared.state.api_url}")
+    # 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
+    if shared.state.completion_url != COMPLETION_URL:
+        logging.info(f"使用自定义API URL: {shared.state.completion_url}")
 
     with retrieve_proxy():
         response = requests.post(
-            shared.state.api_url,
+            shared.state.completion_url,
            headers=headers,
            json=payload,
            stream=True,
@@ -270,6 +268,11 @@ def predict(
     reply_language="中文",
     should_check_token_count=True,
 ): # repetition_penalty, top_k
+    from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
+    from llama_index.indices.query.schema import QueryBundle
+    from langchain.llms import OpenAIChat
+
+
     logging.info("输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL)
     if should_check_token_count:
         yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
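Besides switching to shared.state.completion_url, this file moves the llama_index/langchain imports from module level into predict(). A tiny sketch of that deferred-import idea, using a standard-library module as a stand-in for the heavy optional dependencies:

def predict(inputs):
    # Heavy or optional dependencies are imported only when the function runs,
    # so importing the module itself stays cheap and does not require them.
    import json  # stdlib stand-in for llama_index / langchain in this sketch
    return json.dumps({"inputs": inputs})

print(predict("hello"))  # {"inputs": "hello"}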
modules/config.py CHANGED
@@ -110,4 +110,4 @@ def retrieve_proxy(proxy=None):
 
 
 ## 处理advance pdf
-advance_pdf = config.get("advance_pdf", {})
+advance_pdf = config.get("advance_pdf", {})
modules/llama_func.py CHANGED
@@ -1,7 +1,6 @@
 import os
 import logging
 
-from llama_index import GPTSimpleVectorIndex, ServiceContext
 from llama_index import download_loader
 from llama_index import (
     Document,
@@ -10,8 +9,6 @@ from llama_index import (
     QuestionAnswerPrompt,
     RefinePrompt,
 )
-from langchain.llms import OpenAI
-from langchain.chat_models import ChatOpenAI
 import colorama
 import PyPDF2
 from tqdm import tqdm
@@ -89,6 +86,9 @@ def construct_index(
     embedding_limit=None,
     separator=" "
 ):
+    from langchain.chat_models import ChatOpenAI
+    from llama_index import GPTSimpleVectorIndex, ServiceContext
+
     os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
     embedding_limit = None if embedding_limit == 0 else embedding_limit
@@ -120,98 +120,7 @@ def construct_index(
         logging.error("索引构建失败!", e)
         print(e)
         return None
-
-
-def chat_ai(
-    api_key,
-    index,
-    question,
-    context,
-    chatbot,
-    reply_language,
-):
-    os.environ["OPENAI_API_KEY"] = api_key
-
-    logging.info(f"Question: {question}")
-
-    response, chatbot_display, status_text = ask_ai(
-        api_key,
-        index,
-        question,
-        replace_today(PROMPT_TEMPLATE),
-        REFINE_TEMPLATE,
-        SIM_K,
-        INDEX_QUERY_TEMPRATURE,
-        context,
-        reply_language,
-    )
-    if response is None:
-        status_text = "查询失败,请换个问法试试"
-        return context, chatbot
-    response = response
-
-    context.append({"role": "user", "content": question})
-    context.append({"role": "assistant", "content": response})
-    chatbot.append((question, chatbot_display))
-
-    os.environ["OPENAI_API_KEY"] = ""
-    return context, chatbot, status_text
-
-
-def ask_ai(
-    api_key,
-    index,
-    question,
-    prompt_tmpl,
-    refine_tmpl,
-    sim_k=5,
-    temprature=0,
-    prefix_messages=[],
-    reply_language="中文",
-):
-    os.environ["OPENAI_API_KEY"] = api_key
-
-    logging.debug("Index file found")
-    logging.debug("Querying index...")
-    llm_predictor = LLMPredictor(
-        llm=ChatOpenAI(
-            temperature=temprature,
-            model_name="gpt-3.5-turbo-0301",
-            prefix_messages=prefix_messages,
-        )
-    )
-
-    response = None  # Initialize response variable to avoid UnboundLocalError
-    qa_prompt = QuestionAnswerPrompt(prompt_tmpl.replace("{reply_language}", reply_language))
-    rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
-    response = index.query(
-        question,
-        similarity_top_k=sim_k,
-        text_qa_template=qa_prompt,
-        refine_template=rf_prompt,
-        response_mode="compact",
-    )
-
-    if response is not None:
-        logging.info(f"Response: {response}")
-        ret_text = response.response
-        nodes = []
-        for index, node in enumerate(response.source_nodes):
-            brief = node.source_text[:25].replace("\n", "")
-            nodes.append(
-                f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
-            )
-        new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
-        logging.info(
-            f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
-        )
-        os.environ["OPENAI_API_KEY"] = ""
-        return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
-    else:
-        logging.warning("No response found, returning None")
-        os.environ["OPENAI_API_KEY"] = ""
-        return None
-
+
 
 def add_space(text):
     punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
modules/openai_func.py CHANGED
@@ -37,7 +37,7 @@ def get_billing_data(openai_api_key, billing_url):
 
 def get_usage(openai_api_key):
     try:
-        balance_data=get_billing_data(openai_api_key, BALANCE_API_URL)
+        balance_data=get_billing_data(openai_api_key, shared.state.balance_api_url)
         logging.debug(balance_data)
         try:
             balance = balance_data["total_available"] if balance_data["total_available"] else 0
@@ -51,7 +51,7 @@ def get_usage(openai_api_key):
         if balance == 0:
             last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
             first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
-            usage_url = f"{USAGE_API_URL}?start_date={first_day_of_month}&end_date={last_day_of_month}"
+            usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
             try:
                 usage_data = get_billing_data(openai_api_key, usage_url)
             except Exception as e:
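get_usage() now resolves the billing endpoints through shared.state instead of the module-level constants, so a custom API-Host also applies to balance and usage queries. A rough sketch of how the usage URL is assembled, using the default host and the same query parameters as the code above:

import datetime

usage_api_url = "https://api.openai.com/dashboard/billing/usage"  # default; overridden when a custom host is set
first_day_of_month = datetime.datetime.now().replace(day=1).strftime("%Y-%m-%d")
last_day_of_month = datetime.datetime.now().strftime("%Y-%m-%d")
usage_url = f"{usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
print(usage_url)  # e.g. .../billing/usage?start_date=2023-04-01&end_date=2023-04-15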
modules/presets.py CHANGED
@@ -3,7 +3,8 @@ import gradio as gr
 
 # ChatGPT 设置
 initial_prompt = "You are a helpful assistant."
-API_URL = "https://api.openai.com/v1/chat/completions"
+API_HOST = "api.openai.com"
+COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
 BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
 USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
 HISTORY_DIR = "history"
modules/shared.py CHANGED
@@ -1,8 +1,10 @@
-from modules.presets import API_URL
-
+from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
+import os
 class State:
     interrupted = False
-    api_url = API_URL
+    completion_url = COMPLETION_URL
+    balance_api_url = BALANCE_API_URL
+    usage_api_url = USAGE_API_URL
 
     def interrupt(self):
         self.interrupted = True
@@ -10,15 +12,21 @@ class State:
     def recover(self):
         self.interrupted = False
 
-    def set_api_url(self, api_url):
-        self.api_url = api_url
+    def set_api_host(self, api_host):
+        self.completion_url = f"https://{api_host}/v1/chat/completions"
+        self.balance_api_url = f"https://{api_host}/dashboard/billing/credit_grants"
+        self.usage_api_url = f"https://{api_host}/dashboard/billing/usage"
+        os.environ["OPENAI_API_BASE"] = f"https://{api_host}/v1"
 
-    def reset_api_url(self):
-        self.api_url = API_URL
-        return self.api_url
+    def reset_api_host(self):
+        self.completion_url = COMPLETION_URL
+        self.balance_api_url = BALANCE_API_URL
+        self.usage_api_url = USAGE_API_URL
+        os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}/v1"
+        return API_HOST
 
     def reset_all(self):
         self.interrupted = False
-        self.api_url = API_URL
+        self.completion_url = COMPLETION_URL
 
 state = State()
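The new State methods derive every endpoint from a single host string, which is what makes the switch "more elegant": callers pass a bare host such as api.openai.com rather than a full completion URL. A trimmed-down usage sketch with only the completion URL kept, where the custom host value is a placeholder, not a real endpoint:

import os

API_HOST = "api.openai.com"
COMPLETION_URL = f"https://{API_HOST}/v1/chat/completions"

class State:
    completion_url = COMPLETION_URL

    def set_api_host(self, api_host):
        # Every endpoint is rebuilt from the bare host name.
        self.completion_url = f"https://{api_host}/v1/chat/completions"
        os.environ["OPENAI_API_BASE"] = f"https://{api_host}/v1"

    def reset_api_host(self):
        self.completion_url = COMPLETION_URL
        os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}/v1"
        return API_HOST

state = State()
state.set_api_host("my-proxy.example.com")  # placeholder host
print(state.completion_url)                 # https://my-proxy.example.com/v1/chat/completions
print(state.reset_api_host())               # api.openai.com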
modules/utils.py CHANGED
@@ -328,14 +328,14 @@ def reset_textbox():
 
 
 def reset_default():
-    newurl = shared.state.reset_api_url()
+    default_host = shared.state.reset_api_host()
     retrieve_proxy("")
-    return gr.update(value=newurl), gr.update(value=""), "API URL 和代理已重置"
+    return gr.update(value=default_host), gr.update(value=""), "API-Host 和代理已重置"
 
 
-def change_api_url(url):
-    shared.state.set_api_url(url)
-    msg = f"API地址更改为了{url}"
+def change_api_host(host):
+    shared.state.set_api_host(host)
+    msg = f"API-Host更改为了{host}"
     logging.info(msg)
     return msg
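reset_default() returns three values because default_btn.click registers three outputs ([apihostTxt, proxyTxt, status_display]): the gr.update(...) objects refresh the first two components and the plain string goes to the status display. A small sketch of that one-return-value-per-output convention, with component names chosen for illustration only:

import gradio as gr

def reset_default():
    # One return value per registered output, in the same order.
    return gr.update(value="api.openai.com"), gr.update(value=""), "API-Host and proxy reset"

with gr.Blocks() as demo:
    apihostTxt = gr.Textbox(label="API-Host")
    proxyTxt = gr.Textbox(label="Proxy")
    status_display = gr.Markdown()
    default_btn = gr.Button("🔙 Restore defaults")
    default_btn.click(reset_default, [], [apihostTxt, proxyTxt, status_display], show_progress=True)

demo.launch()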