Commit c12b724 by Tuchuanhuhuhu (1 parent: 60fe470)

Add real-time answering in index mode; adapt to llama_index 0.5.0; add Traditional Chinese support
modules/chat_func.py CHANGED
@@ -13,6 +13,9 @@ import colorama
 from duckduckgo_search import ddg
 import asyncio
 import aiohttp
+from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
+from llama_index.indices.query.schema import QueryBundle
+from langchain.llms import OpenAIChat
 
 from modules.presets import *
 from modules.llama_func import *
@@ -63,7 +66,7 @@ def get_response(
     # If a custom API URL is configured, send the request there; otherwise use the default
     if shared.state.api_url != API_URL:
         logging.info(f"使用自定义API URL: {shared.state.api_url}")
-
+
     response = requests.post(
         shared.state.api_url,
         headers=headers,
@@ -72,7 +75,7 @@ def get_response(
         timeout=timeout,
         proxies=proxies,
     )
-
+
     return response
 
 
@@ -103,13 +106,17 @@ def stream_predict(
     else:
         chatbot.append((inputs, ""))
     user_token_count = 0
+    if fake_input is not None:
+        input_token_count = count_token(construct_user(fake_input))
+    else:
+        input_token_count = count_token(construct_user(inputs))
     if len(all_token_counts) == 0:
         system_prompt_token_count = count_token(construct_system(system_prompt))
         user_token_count = (
-            count_token(construct_user(inputs)) + system_prompt_token_count
+            input_token_count + system_prompt_token_count
         )
     else:
-        user_token_count = count_token(construct_user(inputs))
+        user_token_count = input_token_count
     all_token_counts.append(user_token_count)
     logging.info(f"输入token计数: {user_token_count}")
     yield get_return_value()
@@ -137,6 +144,8 @@ def stream_predict(
     yield get_return_value()
     error_json_str = ""
 
+    if fake_input is not None:
+        history[-2] = construct_user(fake_input)
     for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
@@ -201,7 +210,10 @@ def predict_all(
         chatbot.append((fake_input, ""))
     else:
         chatbot.append((inputs, ""))
-    all_token_counts.append(count_token(construct_user(inputs)))
+    if fake_input is not None:
+        all_token_counts.append(count_token(construct_user(fake_input)))
+    else:
+        all_token_counts.append(count_token(construct_user(inputs)))
     try:
         response = get_response(
             openai_api_key,
@@ -224,13 +236,22 @@ def predict_all(
         status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
         return chatbot, history, status_text, all_token_counts
     response = json.loads(response.text)
-    content = response["choices"][0]["message"]["content"]
-    history[-1] = construct_assistant(content)
-    chatbot[-1] = (chatbot[-1][0], content+display_append)
-    total_token_count = response["usage"]["total_tokens"]
-    all_token_counts[-1] = total_token_count - sum(all_token_counts)
-    status_text = construct_token_message(total_token_count)
-    return chatbot, history, status_text, all_token_counts
+    if fake_input is not None:
+        history[-2] = construct_user(fake_input)
+    try:
+        content = response["choices"][0]["message"]["content"]
+        history[-1] = construct_assistant(content)
+        chatbot[-1] = (chatbot[-1][0], content+display_append)
+        total_token_count = response["usage"]["total_tokens"]
+        if fake_input is not None:
+            all_token_counts[-1] += count_token(construct_assistant(content))
+        else:
+            all_token_counts[-1] = total_token_count - sum(all_token_counts)
+        status_text = construct_token_message(total_token_count)
+        return chatbot, history, status_text, all_token_counts
+    except KeyError:
+        status_text = standard_error_msg + str(response)
+        return chatbot, history, status_text, all_token_counts
 
 
 def predict(
@@ -254,37 +275,55 @@ def predict(
     yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
     if reply_language == "跟随问题语言(不稳定)":
         reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
+    old_inputs = None
+    display_reference = []
+    limited_context = False
     if files:
+        limited_context = True
+        old_inputs = inputs
         msg = "加载索引中……(这可能需要几分钟)"
         logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
         index = construct_index(openai_api_key, file_src=files)
         msg = "索引构建完成,获取回答中……"
+        logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
-        history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot, reply_language)
-        yield chatbot, history, status_text, all_token_counts
-        return
-
-    old_inputs = ""
-    link_references = []
-    if use_websearch:
+        llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
+        prompt_helper = PromptHelper(max_input_size=4096, num_output=5, max_chunk_overlap=20, chunk_size_limit=600)
+        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
+        query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
+        query_bundle = QueryBundle(inputs)
+        nodes = query_object.retrieve(query_bundle)
+        reference_results = [n.node.text for n in nodes]
+        reference_results = add_source_numbers(reference_results, use_source=False)
+        display_reference = add_details(reference_results)
+        display_reference = "\n\n" + "".join(display_reference)
+        inputs = (
+            replace_today(PROMPT_TEMPLATE)
+            .replace("{query_str}", inputs)
+            .replace("{context_str}", "\n\n".join(reference_results))
+            .replace("{reply_language}", reply_language)
+        )
+    elif use_websearch:
+        limited_context = True
         search_results = ddg(inputs, max_results=5)
         old_inputs = inputs
-        web_results = []
+        reference_results = []
         for idx, result in enumerate(search_results):
            logging.info(f"搜索结果{idx + 1}:{result}")
            domain_name = urllib3.util.parse_url(result["href"]).host
-            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
-            link_references.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
-        link_references = "\n\n" + "".join(link_references)
+            reference_results.append([result["body"], result["href"]])
+            display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
+        reference_results = add_source_numbers(reference_results)
+        display_reference = "\n\n" + "".join(display_reference)
         inputs = (
             replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
             .replace("{query}", inputs)
-            .replace("{web_results}", "\n\n".join(web_results))
+            .replace("{web_results}", "\n\n".join(reference_results))
             .replace("{reply_language}", reply_language)
         )
     else:
-        link_references = ""
+        display_reference = ""
 
     if len(openai_api_key) != 51:
         status_text = standard_error_msg + no_apikey_msg
@@ -317,7 +356,7 @@ def predict(
             temperature,
             selected_model,
             fake_input=old_inputs,
-            display_append=link_references
+            display_append=display_reference
         )
         for chatbot, history, status_text, all_token_counts in iter:
             if shared.state.interrupted:
@@ -337,7 +376,7 @@ def predict(
             temperature,
             selected_model,
             fake_input=old_inputs,
-            display_append=link_references
+            display_append=display_reference
         )
         yield chatbot, history, status_text, all_token_counts
 
@@ -350,6 +389,11 @@ def predict(
         + colorama.Style.RESET_ALL
     )
 
+    if limited_context:
+        history = history[-4:]
+        all_token_counts = all_token_counts[-2:]
+        yield chatbot, history, status_text, all_token_counts
+
     if stream:
         max_token = max_token_streaming
     else:
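The heart of the new real-time index mode is in the predict() hunk above: instead of delegating to chat_ai() and returning one finished answer, predict() now retrieves the top-k chunks itself, stuffs them into PROMPT_TEMPLATE, and lets the ordinary streaming path generate the reply, so index-mode answers stream in real time with their sources attached. A minimal sketch of just the retrieval step, assuming llama_index 0.5.x as this commit targets (the helper name retrieve_references is ours):

    # Retrieval-only query: fetch the most similar chunks, let the chat code answer.
    from llama_index import LLMPredictor, PromptHelper, ServiceContext
    from llama_index.indices.query.schema import QueryBundle
    from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
    from langchain.llms import OpenAIChat

    def retrieve_references(index, question, model_name="gpt-3.5-turbo"):
        llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=model_name))
        prompt_helper = PromptHelper(max_input_size=4096, num_output=5,
                                     max_chunk_overlap=20, chunk_size_limit=600)
        service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor, prompt_helper=prompt_helper
        )
        # Built directly on the index internals, exactly as in the hunk above;
        # retrieve() fetches nodes without asking the LLM to synthesize an answer.
        query_object = GPTVectorStoreIndexQuery(
            index.index_struct,
            service_context=service_context,
            similarity_top_k=5,
            vector_store=index._vector_store,
            docstore=index._docstore,
        )
        nodes = query_object.retrieve(QueryBundle(question))
        return [n.node.text for n in nodes]

Because the stuffed prompt is much larger than the visible question, the new limited_context flag later trims history to the last four messages and all_token_counts to the last two entries, so the reference text is not carried into subsequent turns.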
 
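A related bookkeeping change runs through stream_predict and predict_all: when fake_input is set (web search or index mode), the prompt actually sent to the API is the stuffed template, so the API's usage.total_tokens would charge the hidden reference text against the visible conversation. The new accounting branches on fake_input instead; in outline (figures illustrative):

    # Stuffed turn (fake_input set): count the visible question and the reply
    # locally; the hidden reference text is deliberately left out of the ledger.
    all_token_counts.append(count_token(construct_user(fake_input)))   # e.g. 21
    all_token_counts[-1] += count_token(construct_assistant(content))  # e.g. +180

    # Plain turn: the API's usage field is authoritative.
    all_token_counts[-1] = total_token_count - sum(all_token_counts)

    # history[-2] is rewritten to the visible question, so the bulky stuffed
    # prompt never persists in the conversation history.
    history[-2] = construct_user(fake_input)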
modules/llama_func.py CHANGED
@@ -13,6 +13,8 @@ from llama_index import (
 from langchain.llms import OpenAI
 from langchain.chat_models import ChatOpenAI
 import colorama
+import PyPDF2
+from tqdm import tqdm
 
 from modules.presets import *
 from modules.utils import *
@@ -29,6 +31,12 @@ def get_index_name(file_src):
 
     return md5_hash.hexdigest()
 
+def block_split(text):
+    blocks = []
+    while len(text) > 0:
+        blocks.append(Document(text[:1000]))
+        text = text[1000:]
+    return blocks
 
 def get_documents(file_src):
     documents = []
@@ -38,9 +46,12 @@ def get_documents(file_src):
         logging.info(f"loading file: {file.name}")
         if os.path.splitext(file.name)[1] == ".pdf":
             logging.debug("Loading PDF...")
-            CJKPDFReader = download_loader("CJKPDFReader")
-            loader = CJKPDFReader()
-            text_raw = loader.load_data(file=file.name)[0].text
+            pdftext = ""
+            with open(file.name, "rb") as pdfFileObj:
+                pdfReader = PyPDF2.PdfReader(pdfFileObj)
+                for page in tqdm(pdfReader.pages):
+                    pdftext += page.extract_text()
+            text_raw = pdftext
         elif os.path.splitext(file.name)[1] == ".docx":
             logging.debug("Loading DOCX...")
             DocxReader = download_loader("DocxReader")
@@ -56,6 +67,8 @@ def get_documents(file_src):
             with open(file.name, "r", encoding="utf-8") as f:
                 text_raw = f.read()
         text = add_space(text_raw)
+        # text = block_split(text)
+        # documents += text
         documents += [Document(text)]
     logging.debug("Documents loaded.")
     return documents
@@ -65,13 +78,11 @@ def construct_index(
     api_key,
     file_src,
     max_input_size=4096,
-    num_outputs=1,
+    num_outputs=5,
     max_chunk_overlap=20,
     chunk_size_limit=600,
     embedding_limit=None,
-    separator=" ",
-    num_children=10,
-    max_keywords_per_chunk=10,
+    separator=" "
 ):
     os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
@@ -81,14 +92,7 @@ def construct_index(
     llm_predictor = LLMPredictor(
         llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
-    prompt_helper = PromptHelper(
-        max_input_size,
-        num_outputs,
-        max_chunk_overlap,
-        embedding_limit,
-        chunk_size_limit,
-        separator=separator,
-    )
+    prompt_helper = PromptHelper(max_input_size=max_input_size, num_output=num_outputs, max_chunk_overlap=max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
     index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
@@ -97,7 +101,7 @@ def construct_index(
     try:
         documents = get_documents(file_src)
         logging.info("构建索引中……")
-        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
+        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
         index = GPTSimpleVectorIndex.from_documents(
             documents, service_context=service_context
         )
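PDF loading swaps the downloadable CJKPDFReader llama-hub loader for plain PyPDF2, which drops a network fetch at load time. The extraction is page-by-page extract_text(); the same pattern as a standalone sketch (the function name pdf_to_text is ours):

    import PyPDF2

    def pdf_to_text(path):
        # Mirror the new get_documents() branch: read pages in order and
        # concatenate their text. extract_text() may return "" for scanned pages.
        text = ""
        with open(path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:  # the commit wraps this loop in tqdm
                text += page.extract_text()
        return text

Note that block_split() is added alongside but its call site stays commented out, so each file is still indexed as a single Document; the 1000-character splitter is groundwork rather than active behavior.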
modules/presets.py CHANGED
@@ -57,7 +57,8 @@ MODELS = [
 ] # available models
 
 REPLY_LANGUAGES = [
-    "中文",
+    "简体中文",
+    "繁體中文",
     "English",
     "日本語",
     "Español",
modules/utils.py CHANGED
@@ -375,8 +375,8 @@ def replace_today(prompt):
 
 
 def get_geoip():
-    response = requests.get("https://ipapi.co/json/", timeout=5)
     try:
+        response = requests.get("https://ipapi.co/json/", timeout=5)
         data = response.json()
     except:
         data = {"error": True, "reason": "连接ipapi失败"}
@@ -384,7 +384,7 @@ def get_geoip():
         logging.warning(f"无法获取IP地址信息。\n{data}")
         if data["reason"] == "RateLimited":
             return (
-                f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
+                f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用。"
             )
         else:
             return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
@@ -457,7 +457,7 @@ def get_proxies():
 
     if proxies == {}:
         proxies = None
-
+
     return proxies
 
 def run(command, desc=None, errdesc=None, custom_env=None, live=False):
@@ -500,4 +500,19 @@ Python: <span title="{sys.version}">{python_version}</span>
 Gradio: {gr.__version__}
  • 
 Commit: {commit_info}
-    """
+    """
+
+def add_source_numbers(lst, source_name="Source", use_source=True):
+    if use_source:
+        return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)]
+    else:
+        return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]
+
+def add_details(lst):
+    nodes = []
+    for index, txt in enumerate(lst):
+        brief = txt[:25].replace("\n", "")
+        nodes.append(
+            f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
+        )
+    return nodes
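The two helpers appended to modules/utils.py format retrieved references for display: add_source_numbers numbers each item, with or without a source URL, and add_details wraps each chunk in a collapsible HTML <details> block whose summary is the first 25 characters. Expected outputs for illustrative inputs:

    # Web-search references arrive as [body, href] pairs:
    add_source_numbers([["snippet text", "https://example.com"]])
    # -> ['[1]\t "snippet text"\nSource: https://example.com']

    # Index-mode references are bare chunk strings:
    add_source_numbers(["retrieved chunk"], use_source=False)
    # -> ['[1]\t "retrieved chunk"']

    # Collapsible blocks for the chat window:
    add_details(["retrieved chunk"])
    # -> ['<details><summary>retrieved chunk...</summary><p>retrieved chunk</p></details>']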
requirements.txt CHANGED
@@ -10,3 +10,4 @@ Pygments
 llama_index
 langchain
 markdown
+PyPDF2