dh-mc commited on
Commit
ee3a625
1 Parent(s): 25309c9

added refine summary chain

Browse files
app_modules/init.py CHANGED
@@ -23,55 +23,59 @@ load_dotenv(found_dotenv, override=False)
23
  init_settings()
24
 
25
 
26
- def app_init():
27
  # https://github.com/huggingface/transformers/issues/17611
28
  os.environ["CURL_CA_BUNDLE"] = ""
29
 
 
 
 
30
  hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
31
  print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
32
  print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
33
 
34
- hf_embeddings_model_name = (
35
- os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
36
- )
37
-
38
- n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
39
- index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
40
- "CHROMADB_INDEX_PATH"
41
- )
42
- using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
43
- llm_model_type = os.environ.get("LLM_MODEL_TYPE")
44
 
45
- start = timer()
46
- embeddings = HuggingFaceInstructEmbeddings(
47
- model_name=hf_embeddings_model_name,
48
- model_kwargs={"device": hf_embeddings_device_type},
49
- )
50
- end = timer()
51
 
52
- print(f"Completed in {end - start:.3f}s")
 
 
 
 
 
53
 
54
- start = timer()
55
 
56
- print(f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}")
57
 
58
- if not os.path.isdir(index_path):
59
- raise ValueError(f"{index_path} does not exist!")
60
- elif using_faiss:
61
- vectorstore = FAISS.load_local(index_path, embeddings)
62
- else:
63
- vectorstore = Chroma(
64
- embedding_function=embeddings, persist_directory=index_path
65
  )
66
 
67
- end = timer()
 
 
 
 
 
 
 
68
 
69
- print(f"Completed in {end - start:.3f}s")
 
 
70
 
71
  start = timer()
72
  llm_loader = LLMLoader(llm_model_type)
73
  llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
74
- qa_chain = QAChain(vectorstore, llm_loader)
75
  end = timer()
76
  print(f"Completed in {end - start:.3f}s")
77
 
 
23
  init_settings()
24
 
25
 
26
+ def app_init(initQAChain: bool = True):
27
  # https://github.com/huggingface/transformers/issues/17611
28
  os.environ["CURL_CA_BUNDLE"] = ""
29
 
30
+ llm_model_type = os.environ.get("LLM_MODEL_TYPE")
31
+ n_threds = int(os.environ.get("NUMBER_OF_CPU_CORES") or "4")
32
+
33
  hf_embeddings_device_type, hf_pipeline_device_type = get_device_types()
34
  print(f"hf_embeddings_device_type: {hf_embeddings_device_type}")
35
  print(f"hf_pipeline_device_type: {hf_pipeline_device_type}")
36
 
37
+ if initQAChain:
38
+ hf_embeddings_model_name = (
39
+ os.environ.get("HF_EMBEDDINGS_MODEL_NAME") or "hkunlp/instructor-xl"
40
+ )
 
 
 
 
 
 
41
 
42
+ index_path = os.environ.get("FAISS_INDEX_PATH") or os.environ.get(
43
+ "CHROMADB_INDEX_PATH"
44
+ )
45
+ using_faiss = os.environ.get("FAISS_INDEX_PATH") is not None
 
 
46
 
47
+ start = timer()
48
+ embeddings = HuggingFaceInstructEmbeddings(
49
+ model_name=hf_embeddings_model_name,
50
+ model_kwargs={"device": hf_embeddings_device_type},
51
+ )
52
+ end = timer()
53
 
54
+ print(f"Completed in {end - start:.3f}s")
55
 
56
+ start = timer()
57
 
58
+ print(
59
+ f"Load index from {index_path} with {'FAISS' if using_faiss else 'Chroma'}"
 
 
 
 
 
60
  )
61
 
62
+ if not os.path.isdir(index_path):
63
+ raise ValueError(f"{index_path} does not exist!")
64
+ elif using_faiss:
65
+ vectorstore = FAISS.load_local(index_path, embeddings)
66
+ else:
67
+ vectorstore = Chroma(
68
+ embedding_function=embeddings, persist_directory=index_path
69
+ )
70
 
71
+ end = timer()
72
+
73
+ print(f"Completed in {end - start:.3f}s")
74
 
75
  start = timer()
76
  llm_loader = LLMLoader(llm_model_type)
77
  llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
78
+ qa_chain = QAChain(vectorstore, llm_loader) if initQAChain else None
79
  end = timer()
80
  print(f"Completed in {end - start:.3f}s")
81
 
app_modules/llm_summarize_chain.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ from langchain.chains.base import Chain
5
+ from langchain.chains.summarize import load_summarize_chain
6
+
7
+ from app_modules.llm_inference import LLMInference
8
+
9
+
10
+ class SummarizeChain(LLMInference):
11
+ def __init__(self, llm_loader):
12
+ super().__init__(llm_loader)
13
+
14
+ def create_chain(self) -> Chain:
15
+ chain = load_summarize_chain(self.llm_loader.llm, chain_type="refine")
16
+ return chain
17
+
18
+ def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
19
+ result = chain(inputs, return_only_outputs=True)
20
+ return result
summarize.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # setting device on GPU if available, else CPU
2
+ import os
3
+ import sys
4
+ from timeit import default_timer as timer
5
+ from typing import List
6
+
7
+ from langchain.document_loaders import PyPDFDirectoryLoader
8
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain.vectorstores.base import VectorStore
11
+ from langchain.vectorstores.chroma import Chroma
12
+ from langchain.vectorstores.faiss import FAISS
13
+
14
+ from app_modules.init import app_init, get_device_types
15
+ from app_modules.llm_summarize_chain import SummarizeChain
16
+
17
+
18
+ def load_documents(source_pdfs_path, urls) -> List:
19
+ loader = PyPDFDirectoryLoader(source_pdfs_path, silent_errors=True)
20
+ documents = loader.load()
21
+ if urls is not None and len(urls) > 0:
22
+ for doc in documents:
23
+ source = doc.metadata["source"]
24
+ filename = source.split("/")[-1]
25
+ for url in urls:
26
+ if url.endswith(filename):
27
+ doc.metadata["url"] = url
28
+ break
29
+ return documents
30
+
31
+
32
+ def split_chunks(documents: List, chunk_size, chunk_overlap) -> List:
33
+ text_splitter = RecursiveCharacterTextSplitter(
34
+ chunk_size=chunk_size, chunk_overlap=chunk_overlap
35
+ )
36
+ return text_splitter.split_documents(documents)
37
+
38
+
39
+ llm_loader = app_init(False)[0]
40
+
41
+ source_pdfs_path = (
42
+ sys.argv[1] if len(sys.argv) > 1 else os.environ.get("SOURCE_PDFS_PATH")
43
+ )
44
+ chunk_size = os.environ.get("CHUNCK_SIZE")
45
+ chunk_overlap = os.environ.get("CHUNK_OVERLAP")
46
+
47
+ sources = load_documents(source_pdfs_path, None)
48
+
49
+ print(f"Splitting {len(sources)} PDF pages in to chunks ...")
50
+
51
+ chunks = split_chunks(
52
+ sources, chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap)
53
+ )
54
+
55
+ print(f"Summarizing {len(chunks)} chunks ...")
56
+ start = timer()
57
+
58
+ summarize_chain = SummarizeChain(llm_loader)
59
+ result = summarize_chain.call_chain(
60
+ {"input_documents": chunks},
61
+ None,
62
+ None,
63
+ True,
64
+ )
65
+
66
+ end = timer()
67
+ print(f"Completed in {end - start:.3f}s")
68
+
69
+ print("\n\n***Summary:")
70
+ print(result["output_text"])