jeongsk committed
Commit
e7055d3
1 Parent(s): 9539f2b

Upload 2 files

Files changed (2)
  1. app.py +163 -0
  2. laas.py +80 -0
app.py ADDED
@@ -0,0 +1,163 @@
+ import os
+ import pickle
+
+ import streamlit as st
+ from dotenv import load_dotenv
+ from langchain.embeddings import CacheBackedEmbeddings
+ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
+ from langchain.storage import LocalFileStore
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+ from langchain_core.vectorstores import VectorStore
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ from laas import ChatLaaS
+
+ # Load environment variables
+ load_dotenv()
+
+ # Enable LangSmith tracing for this project
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
+ os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"
+
+
+ @st.cache_resource
+ def setup_embeddings_and_db(project_folder: str) -> FAISS:
+     """Load the cached embedding model and the project's FAISS index."""
+     CACHE_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".cache")
+     CACHE_MODELS_PATH = os.path.join(CACHE_ROOT_PATH, "models")
+     CACHE_EMBEDDINGS_PATH = os.path.join(CACHE_ROOT_PATH, "embeddings")
+
+     os.makedirs(CACHE_MODELS_PATH, exist_ok=True)
+     os.makedirs(CACHE_EMBEDDINGS_PATH, exist_ok=True)
+
+     store = LocalFileStore(CACHE_EMBEDDINGS_PATH)
+
+     model_name = "BAAI/bge-m3"
+     model_kwargs = {"device": "mps"}  # Apple Silicon; use "cuda" or "cpu" elsewhere
+     encode_kwargs = {"normalize_embeddings": False}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs,
+         cache_folder=CACHE_MODELS_PATH,
+         multi_process=False,
+         show_progress=True,
+     )
+
+     # Cache computed embeddings on disk so repeated runs are fast
+     cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
+         embeddings,
+         store,
+         namespace=embeddings.model_name,
+     )
+
+     FAISS_DB_INDEX = os.path.join(project_folder, "langchain_faiss")
+     db = FAISS.load_local(
+         FAISS_DB_INDEX,  # directory containing the FAISS index to load
+         cached_embeddings,  # embedding model the index was built with
+         allow_dangerous_deserialization=True,  # the index is a trusted local pickle
+     )
+
+     return db
+
+
+ # Function to set up retrievers and chain
+ @st.cache_resource
+ def setup_retrievers_and_chain(_db: VectorStore, project_folder: str):
+     """Build the ensemble retriever, reranker, and RAG chain.
+
+     The leading underscore on `_db` tells st.cache_resource not to hash it.
+     """
+     faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})
+
+     bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
+     with open(bm25_retriever_path, "rb") as f:
+         bm25_retriever = pickle.load(f)
+     bm25_retriever.k = 20
+
+     # Combine keyword (BM25) and dense (FAISS) retrieval
+     ensemble_retriever = EnsembleRetriever(
+         retrievers=[bm25_retriever, faiss_retriever],
+         weights=[0.6, 0.4],
+     )
+
+     # Rerank the merged candidates with a cross-encoder and keep the top 5
+     model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
+     compressor = CrossEncoderReranker(model=model, top_n=5)
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=compressor,
+         base_retriever=ensemble_retriever,
+     )
+
+     laas = ChatLaaS(
+         project=st.secrets["LAAS_PROJECT"],
+         api_key=st.secrets["LAAS_API_KEY"],
+         hash=st.secrets["LAAS_HASH"],
+     )
+
+     rag_chain = (
+         {
+             # Stringify the retrieved Document list so it can fill the prompt
+             "context": compression_retriever | RunnableLambda(lambda docs: str(docs)),
+             "question": RunnablePassthrough(),
+         }
+         | RunnableLambda(
+             lambda x: laas.invoke(
+                 "", params={"context": x["context"], "question": x["question"]}
+             )
+         )
+         | StrOutputParser()
+     )
+
+     return rag_chain
+
+
+ def main():
+     st.title("Code QA Bot")
+
+     # Initialize session state for the project folder and the last answer
+     if "project_folder" not in st.session_state:
+         st.session_state.project_folder = ""
+     if "answer" not in st.session_state:
+         st.session_state.answer = ""
+
+     # Get the project folder path from the user
+     project_folder = st.text_input(
+         "Enter the project folder path:", value=st.session_state.project_folder
+     )
+     st.session_state.project_folder = project_folder
+
+     if project_folder:
+         # Once a project folder is given, set up the vector store and the chain
+         db = setup_embeddings_and_db(project_folder)
+         rag_chain = setup_retrievers_and_chain(db, project_folder)
+
+         # Get the user's question
+         user_question = st.text_input("Ask a question about the code:")
+
+         # Button to clear the previous answer
+         if st.button("Reset Answer"):
+             st.session_state.answer = ""
+
+         if user_question:
+             with st.spinner("Generating answer..."):
+                 st.session_state.answer = rag_chain.invoke(user_question)
+
+         # Display the answer
+         if st.session_state.answer:
+             st.write(st.session_state.answer)
+     else:
+         st.warning("Please enter the project folder path to proceed.")
+
+
+ if __name__ == "__main__":
+     main()
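
Note: setup_embeddings_and_db() and setup_retrievers_and_chain() expect a prebuilt langchain_faiss index and a pickled bm25_retriever.pkl inside the project folder, but this commit does not include the script that builds them. The original file's now-removed imports (GenericLoader, LanguageParser, RecursiveCharacterTextSplitter) hint at how the index was made; below is a minimal indexing sketch under that assumption. The path and chunking parameters are hypothetical.

import os
import pickle

from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

project_folder = "/path/to/project"  # hypothetical

# Parse Python sources at the function/class level
loader = GenericLoader.from_filesystem(
    project_folder,
    glob="**/*",
    suffixes=[".py"],
    parser=LanguageParser(language=Language.PYTHON),
)
docs = loader.load()

# Split with a Python-aware splitter (chunk sizes are hypothetical)
splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
chunks = splitter.split_documents(docs)

# Build and save the FAISS index that setup_embeddings_and_db() loads;
# the embedding model must match the one used at query time
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
db = FAISS.from_documents(chunks, embeddings)
db.save_local(os.path.join(project_folder, "langchain_faiss"))

# Pickle the BM25 retriever that setup_retrievers_and_chain() loads
bm25 = BM25Retriever.from_documents(chunks)
with open(os.path.join(project_folder, "bm25_retriever.pkl"), "wb") as f:
    pickle.dump(bm25, f)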
laas.py ADDED
@@ -0,0 +1,80 @@
+ import logging
+ from typing import Any, List, Optional
+
+ import requests
+ from langchain_core.callbacks import CallbackManagerForLLMRun
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.messages import AIMessage, BaseMessage
+ from langchain_core.outputs import ChatGeneration, ChatResult
+ from langchain_core.pydantic_v1 import Field, SecretStr
+
+ logger = logging.getLogger(__name__)
+
+
+ class ChatLaaS(BaseChatModel):
+     """Chat model backed by the Wanted LaaS preset API."""
+
+     laas_api_base: Optional[str] = Field(
+         default="https://api-laas.wanted.co.kr/api/preset", alias="base_url"
+     )
+     laas_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+     laas_project: Optional[str] = Field(default=None, alias="project")
+     laas_hash: Optional[str] = Field(default=None, alias="hash")
+     timeout: Optional[float] = Field(default=60.0)
+
+     # Map LangChain message types onto the roles the LaaS API expects
+     _ROLE_MAP = {
+         "human": "user",
+         "ai": "assistant",
+     }
+
+     @property
+     def _llm_type(self) -> str:
+         """Return the type of chat model."""
+         return "laas-chat"
+
+     @classmethod
+     def is_lc_serializable(cls) -> bool:
+         """Return whether this model can be serialized by LangChain."""
+         return False
+
+     def _generate(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         body = {
+             "hash": self.laas_hash,
+             "messages": [
+                 {
+                     "role": self._ROLE_MAP.get(msg.type, msg.type),
+                     "content": msg.content,
+                 }
+                 for msg in messages
+                 if msg.content.strip()  # drop empty or whitespace-only messages
+             ],
+             **kwargs,
+         }
+         logger.debug("LaaS request body: %s", body)
+
+         headers = {
+             "Content-Type": "application/json",
+             "apiKey": self.laas_api_key.get_secret_value(),
+             "project": self.laas_project,
+         }
+
+         response = requests.post(
+             f"{self.laas_api_base}/chat/completions",
+             headers=headers,
+             json=body,
+             timeout=self.timeout,
+         )
+         response.raise_for_status()
+         data = response.json()
+
+         # Extract the assistant message from the API response
+         content = data["choices"][0]["message"]["content"]
+         message = AIMessage(id=data["id"], content=content)
+         generation = ChatGeneration(message=message)
+         return ChatResult(generations=[generation])
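
For reference, ChatLaaS follows the standard LangChain chat-model interface; app.py drives it with an empty message plus a params dict, which presumably fills the preset's prompt template server-side. A minimal usage sketch mirroring that call (the credential values are hypothetical placeholders):

from laas import ChatLaaS

laas = ChatLaaS(
    project="my-project",    # hypothetical
    api_key="laas-api-key",  # hypothetical
    hash="preset-hash",      # hypothetical
)

# Extra kwargs such as `params` are forwarded into the request body;
# the empty message is filtered out, so the preset supplies the prompt.
result = laas.invoke("", params={"context": "...", "question": "What does main() do?"})
print(result.content)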