xingfanxia committed on
Commit
31013be
1 Parent(s): 9f551dd

bugfix: fix gpt index 0.5.0 breaking changes

Files changed (1)
  1. modules/llama_func.py +18 -11
modules/llama_func.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import logging
 
-from llama_index import GPTSimpleVectorIndex
+from llama_index import GPTSimpleVectorIndex, ServiceContext
 from llama_index import download_loader
 from llama_index import (
     Document,
@@ -11,6 +11,7 @@ from llama_index import (
     RefinePrompt,
 )
 from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
 import colorama
 
 from modules.presets import *
@@ -56,6 +57,7 @@ def get_documents(file_src):
                 text_raw = f.read()
         text = add_space(text_raw)
         documents += [Document(text)]
+    logging.debug("Documents loaded.")
     return documents


@@ -77,7 +79,7 @@ def construct_index(
     separator = " " if separator == "" else separator
 
     llm_predictor = LLMPredictor(
-        llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
+        llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
     prompt_helper = PromptHelper(
         max_input_size,
@@ -94,13 +96,19 @@ def construct_index(
     else:
         try:
             documents = get_documents(file_src)
-            logging.debug("构建索引中……")
-            index = GPTSimpleVectorIndex(
-                documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
-            )
-            os.makedirs("./index", exist_ok=True)
-            index.save_to_disk(f"./index/{index_name}.json")
-            return index
+            logging.info("构建索引中……")
+            service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
+            try:
+                index = GPTSimpleVectorIndex.from_documents(
+                    documents, service_context=service_context
+                )
+                logging.info("索引构建完成!")
+                os.makedirs("./index", exist_ok=True)
+                index.save_to_disk(f"./index/{index_name}.json")
+                logging.info("索引已保存至本地!")
+                return index
+            except Exception as e:
+                logging.error("索引构建失败!", e)
         except Exception as e:
             print(e)
             return None
@@ -158,7 +166,7 @@ def ask_ai(
         logging.debug("Index file found")
         logging.debug("Querying index...")
         llm_predictor = LLMPredictor(
-            llm=OpenAI(
+            llm=ChatOpenAI(
                 temperature=temprature,
                 model_name="gpt-3.5-turbo-0301",
                 prefix_messages=prefix_messages,
@@ -170,7 +178,6 @@ def ask_ai(
         rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
         response = index.query(
             question,
-            llm_predictor=llm_predictor,
             similarity_top_k=sim_k,
             text_qa_template=qa_prompt,
             refine_template=rf_prompt,
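
For context, here is a minimal standalone sketch of the 0.5.x flow this commit migrates to, assuming llama_index 0.5.x and a langchain release that provides langchain.chat_models.ChatOpenAI. The document text, the ./index/example.json path, the environment-variable key lookup, and the query string are placeholders for illustration, not taken from the repo.

# Sketch of the llama_index 0.5.x pattern used by this commit (assumptions noted above).
import logging
import os

from langchain.chat_models import ChatOpenAI
from llama_index import (
    Document,
    GPTSimpleVectorIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
)

logging.basicConfig(level=logging.INFO)

# Chat-completion model wrapped for llama_index; the key is read from the
# environment here (placeholder setup; the app passes api_key explicitly).
llm_predictor = LLMPredictor(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=os.environ["OPENAI_API_KEY"])
)
prompt_helper = PromptHelper(4096, 1, 20)  # max_input_size, num_output, max_chunk_overlap

# 0.5.0 breaking change: the predictor and prompt helper now travel in a
# ServiceContext instead of being passed to the index constructor or to query().
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor, prompt_helper=prompt_helper
)

documents = [Document("Some text worth indexing.")]  # placeholder documents
index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)

os.makedirs("./index", exist_ok=True)
index.save_to_disk("./index/example.json")  # placeholder index name

# Querying no longer takes llm_predictor; the index carries its service context.
response = index.query("What is the text about?", similarity_top_k=1)
print(response)

Bundling the predictor into ServiceContext.from_defaults is what lets the commit drop the llm_predictor keyword from index.query in ask_ai.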