JohnSmith9982
commited on
Commit
•
44846b2
1
Parent(s):
8ad9e26
Upload 7 files
Browse files- modules/chat_func.py +2 -2
- modules/llama_func.py +42 -38
modules/chat_func.py
CHANGED
@@ -155,7 +155,7 @@ def stream_predict(
|
|
155 |
yield get_return_value()
|
156 |
error_json_str = ""
|
157 |
|
158 |
-
for chunk in response.iter_lines():
|
159 |
if counter == 0:
|
160 |
counter += 1
|
161 |
continue
|
@@ -272,7 +272,7 @@ def predict(
|
|
272 |
if reply_language == "跟随问题语言(不稳定)":
|
273 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
274 |
if files:
|
275 |
-
msg = "
|
276 |
logging.info(msg)
|
277 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
278 |
index = construct_index(openai_api_key, file_src=files)
|
|
|
155 |
yield get_return_value()
|
156 |
error_json_str = ""
|
157 |
|
158 |
+
for chunk in tqdm(response.iter_lines()):
|
159 |
if counter == 0:
|
160 |
counter += 1
|
161 |
continue
|
|
|
272 |
if reply_language == "跟随问题语言(不稳定)":
|
273 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
274 |
if files:
|
275 |
+
msg = "加载索引中……(这可能需要几分钟)"
|
276 |
logging.info(msg)
|
277 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
278 |
index = construct_index(openai_api_key, file_src=files)
|
modules/llama_func.py
CHANGED
@@ -13,54 +13,57 @@ from llama_index import (
|
|
13 |
from langchain.llms import OpenAI
|
14 |
import colorama
|
15 |
|
16 |
-
|
17 |
from modules.presets import *
|
18 |
from modules.utils import *
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def get_documents(file_src):
|
22 |
documents = []
|
23 |
-
index_name = ""
|
24 |
logging.debug("Loading documents...")
|
25 |
logging.debug(f"file_src: {file_src}")
|
26 |
for file in file_src:
|
27 |
-
logging.
|
28 |
-
index_name += file.name
|
29 |
if os.path.splitext(file.name)[1] == ".pdf":
|
30 |
logging.debug("Loading PDF...")
|
31 |
CJKPDFReader = download_loader("CJKPDFReader")
|
32 |
loader = CJKPDFReader()
|
33 |
-
|
34 |
elif os.path.splitext(file.name)[1] == ".docx":
|
35 |
logging.debug("Loading DOCX...")
|
36 |
DocxReader = download_loader("DocxReader")
|
37 |
loader = DocxReader()
|
38 |
-
|
39 |
elif os.path.splitext(file.name)[1] == ".epub":
|
40 |
logging.debug("Loading EPUB...")
|
41 |
EpubReader = download_loader("EpubReader")
|
42 |
loader = EpubReader()
|
43 |
-
|
44 |
else:
|
45 |
logging.debug("Loading text file...")
|
46 |
with open(file.name, "r", encoding="utf-8") as f:
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
return documents
|
51 |
|
52 |
|
53 |
def construct_index(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
):
|
65 |
os.environ["OPENAI_API_KEY"] = api_key
|
66 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
@@ -78,12 +81,13 @@ def construct_index(
|
|
78 |
chunk_size_limit,
|
79 |
separator=separator,
|
80 |
)
|
81 |
-
|
82 |
if os.path.exists(f"./index/{index_name}.json"):
|
83 |
logging.info("找到了缓存的索引文件,加载中……")
|
84 |
return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
|
85 |
else:
|
86 |
try:
|
|
|
87 |
logging.debug("构建索引中……")
|
88 |
index = GPTSimpleVectorIndex(
|
89 |
documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
|
@@ -97,12 +101,12 @@ def construct_index(
|
|
97 |
|
98 |
|
99 |
def chat_ai(
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
):
|
107 |
os.environ["OPENAI_API_KEY"] = api_key
|
108 |
|
@@ -133,15 +137,15 @@ def chat_ai(
|
|
133 |
|
134 |
|
135 |
def ask_ai(
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
):
|
146 |
os.environ["OPENAI_API_KEY"] = api_key
|
147 |
|
@@ -174,7 +178,7 @@ def ask_ai(
|
|
174 |
for index, node in enumerate(response.source_nodes):
|
175 |
brief = node.source_text[:25].replace("\n", "")
|
176 |
nodes.append(
|
177 |
-
f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
|
178 |
)
|
179 |
new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
|
180 |
logging.info(
|
|
|
13 |
from langchain.llms import OpenAI
|
14 |
import colorama
|
15 |
|
|
|
16 |
from modules.presets import *
|
17 |
from modules.utils import *
|
18 |
|
19 |
+
def get_index_name(file_src):
|
20 |
+
index_name = ""
|
21 |
+
for file in file_src:
|
22 |
+
index_name += os.path.basename(file.name)
|
23 |
+
index_name = sha1sum(index_name)
|
24 |
+
return index_name
|
25 |
|
26 |
def get_documents(file_src):
|
27 |
documents = []
|
|
|
28 |
logging.debug("Loading documents...")
|
29 |
logging.debug(f"file_src: {file_src}")
|
30 |
for file in file_src:
|
31 |
+
logging.info(f"loading file: {file.name}")
|
|
|
32 |
if os.path.splitext(file.name)[1] == ".pdf":
|
33 |
logging.debug("Loading PDF...")
|
34 |
CJKPDFReader = download_loader("CJKPDFReader")
|
35 |
loader = CJKPDFReader()
|
36 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
37 |
elif os.path.splitext(file.name)[1] == ".docx":
|
38 |
logging.debug("Loading DOCX...")
|
39 |
DocxReader = download_loader("DocxReader")
|
40 |
loader = DocxReader()
|
41 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
42 |
elif os.path.splitext(file.name)[1] == ".epub":
|
43 |
logging.debug("Loading EPUB...")
|
44 |
EpubReader = download_loader("EpubReader")
|
45 |
loader = EpubReader()
|
46 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
47 |
else:
|
48 |
logging.debug("Loading text file...")
|
49 |
with open(file.name, "r", encoding="utf-8") as f:
|
50 |
+
text_raw = f.read()
|
51 |
+
text = add_space(text_raw)
|
52 |
+
documents += [Document(text)]
|
53 |
+
return documents
|
54 |
|
55 |
|
56 |
def construct_index(
|
57 |
+
api_key,
|
58 |
+
file_src,
|
59 |
+
max_input_size=4096,
|
60 |
+
num_outputs=1,
|
61 |
+
max_chunk_overlap=20,
|
62 |
+
chunk_size_limit=600,
|
63 |
+
embedding_limit=None,
|
64 |
+
separator=" ",
|
65 |
+
num_children=10,
|
66 |
+
max_keywords_per_chunk=10,
|
67 |
):
|
68 |
os.environ["OPENAI_API_KEY"] = api_key
|
69 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
|
|
81 |
chunk_size_limit,
|
82 |
separator=separator,
|
83 |
)
|
84 |
+
index_name = get_index_name(file_src)
|
85 |
if os.path.exists(f"./index/{index_name}.json"):
|
86 |
logging.info("找到了缓存的索引文件,加载中……")
|
87 |
return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
|
88 |
else:
|
89 |
try:
|
90 |
+
documents = get_documents(file_src)
|
91 |
logging.debug("构建索引中……")
|
92 |
index = GPTSimpleVectorIndex(
|
93 |
documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
|
|
|
101 |
|
102 |
|
103 |
def chat_ai(
|
104 |
+
api_key,
|
105 |
+
index,
|
106 |
+
question,
|
107 |
+
context,
|
108 |
+
chatbot,
|
109 |
+
reply_language,
|
110 |
):
|
111 |
os.environ["OPENAI_API_KEY"] = api_key
|
112 |
|
|
|
137 |
|
138 |
|
139 |
def ask_ai(
|
140 |
+
api_key,
|
141 |
+
index,
|
142 |
+
question,
|
143 |
+
prompt_tmpl,
|
144 |
+
refine_tmpl,
|
145 |
+
sim_k=1,
|
146 |
+
temprature=0,
|
147 |
+
prefix_messages=[],
|
148 |
+
reply_language="中文",
|
149 |
):
|
150 |
os.environ["OPENAI_API_KEY"] = api_key
|
151 |
|
|
|
178 |
for index, node in enumerate(response.source_nodes):
|
179 |
brief = node.source_text[:25].replace("\n", "")
|
180 |
nodes.append(
|
181 |
+
f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
|
182 |
)
|
183 |
new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
|
184 |
logging.info(
|