TillLangbein committed on
Commit cfa680c
1 Parent(s): 1edb596

reranker instead of a retrieval grader.

Files changed (4):
  1. .gitignore +2 -1
  2. app.py +64 -95
  3. prompts.py +9 -15
  4. requirements.txt +14 -1
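
This commit swaps the per-document LLM relevance grader (RELEVANCE_PROMPT plus the GradeDocuments schema) for a local cross-encoder reranker wired in front of each FAISS retriever. A minimal sketch of the new retrieval pattern, not the verbatim app code, assuming an index saved at the placeholder path "dora_index" and an OPENAI_API_KEY in the environment; the query string is illustrative:

# Sketch: over-fetch with MMR, then let a local cross-encoder keep the best 7.
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=7)

vectorstore = FAISS.load_local("dora_index", OpenAIEmbeddings(),  # placeholder path
                               allow_dangerous_deserialization=True)
base = vectorstore.as_retriever(search_type="mmr",
                                search_kwargs={"k": 10, "fetch_k": 20})
retriever = ContextualCompressionRetriever(base_compressor=compressor,
                                           base_retriever=base)

docs = retriever.invoke("What are DORA's incident reporting obligations?")

This removes one LLM round-trip per retrieved document from the hot path; the reranker runs locally, which is also why requirements.txt below picks up torch, transformers, and sentence-transformers.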
.gitignore CHANGED
@@ -1 +1,2 @@
-test_env/
+test_env/
+.cache.db
app.py CHANGED
@@ -10,13 +10,17 @@ from langchain_openai import OpenAIEmbeddings
 from langgraph.graph import END, StateGraph, START
 from langchain_core.output_parsers import StrOutputParser
 
-import asyncio
 from typing import List
 from typing_extensions import TypedDict
 import gradio as gr
 from pydantic import BaseModel, Field
 
-from prompts import IMPROVE_PROMPT, RELEVANCE_PROMPT, ANSWER_PROMPT, HALLUCINATION_PROMPT, RESOLVER_PROMPT, REWRITER_PROMPT
+# For the reranking step
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+
+from prompts import IMPROVE_PROMPT, ANSWER_PROMPT, HALLUCINATION_PROMPT, RESOLVER_PROMPT, REWRITER_PROMPT
 
 TOPICS = [
     "ICT strategy management",
@@ -47,13 +51,6 @@ TOPICS = [
     "ICT business continuity management"
 ]
 
-class GradeDocuments(BaseModel):
-    """Binary score for relevance check on retrieved documents."""
-
-    binary_score: str = Field(
-        description="Documents are relevant to the question, 'yes' or 'no'"
-    )
-
 class GradeHallucinations(BaseModel):
     """Binary score for hallucination present in generation answer."""
 
@@ -82,7 +79,6 @@ class GraphState(TypedDict):
     selected_sources: List[List[bool]]
     generation: str
     documents: List[str]
-    fitting_documents: List[str]
     dora_docs: List[str]
     dora_rts_docs: List[str]
     dora_news_docs: List[str]
@@ -95,18 +91,24 @@ def _set_env(var: str):
 def load_vectorstores(paths: list):
     # The dora vectorstore
     embd = OpenAIEmbeddings()
+    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
+    compressor = CrossEncoderReranker(model=model, top_n=7)
 
     vectorstores = [FAISS.load_local(path, embd, allow_dangerous_deserialization=True) for path in paths]
-    retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
-        "k": 7,
-        "fetch_k": 10,
+    base_retrievers = [vectorstore.as_retriever(search_type="mmr", search_kwargs={
+        "k": 10,
+        "fetch_k": 20,
         "score_threshold": 0.7,
     }) for vectorstore in vectorstores]
 
+    retrievers = [ContextualCompressionRetriever(
+        base_compressor=compressor, base_retriever=retriever
+    ) for retriever in base_retrievers]
+
     return retrievers
 
 # Put all chains in functions
-async def dora_rewrite(state):
+def dora_rewrite(state):
     """
     Rewrites the question to fit dora wording
 
@@ -119,14 +121,14 @@ async def dora_rewrite(state):
     print("---TRANSLATE TO DORA---")
     question = state["question"]
 
-    new_question = await dora_question_rewriter.ainvoke({"question": question, "topics": TOPICS})
+    new_question = dora_question_rewriter.invoke({"question": question, "topics": TOPICS})
 
     if new_question == "Thats an interesting question, but I dont think I can answer it based on my Dora knowledge.":
         return {"question": new_question, "generation": new_question}
     else:
         return {"question": new_question}
 
-async def retrieve(state):
+def retrieve(state):
     """
     Retrieve documents
 
@@ -141,52 +143,17 @@ async def retrieve(state):
     selected_sources = state["selected_sources"]
 
     # Retrieval
-    documents = []
-    if selected_sources[0]:
-        documents.extend(await dora_retriever.ainvoke(question))
-    if selected_sources[1]:
-        documents.extend(await dora_rts_retriever.ainvoke(question))
-    if selected_sources[2]:
-        documents.extend(await dora_news_retriever.ainvoke(question))
-
-    return {"documents": documents, "question": question}
-
-async def grade_documents(state):
-    """
-    Determines whether the retrieved documents are relevant to the question.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        state (dict): Updates documents key with only filtered relevant documents
-    """
-
-    print("---CHECK DOCUMENTS RELEVANCE TO QUESTION---")
-    question = state["question"]
-    documents = state["documents"]
-    fitting_documents = state["fitting_documents"] if "fitting_documents" in state else []
-
-    # Score each doc
-    for d in documents:
-        score = await retrieval_grader.ainvoke(
-            {"question": question, "document": d.page_content}
-        )
-        grade = score.binary_score
-        if grade == "yes":
-            #print("---GRADE: DOCUMENT RELEVANT---")
-            if d in fitting_documents:
-                #print(f"---Document {d.page_content} already in fitting documents---")
-                continue
-            fitting_documents.append(d)
-        else:
-            #print("---GRADE: DOCUMENT NOT RELEVANT---")
-            continue
-
-    return {"fitting_documents": fitting_documents}
-
-async def generate(state):
+    dora_docs = dora_retriever.invoke(question) if selected_sources[0] else []
+    dora_rts_docs = dora_rts_retriever.invoke(question) if selected_sources[1] else []
+    dora_news_docs = dora_news_retriever.invoke(question) if selected_sources[2] else []
+
+    documents = dora_docs + dora_rts_docs + dora_news_docs
+
+    return {"documents": documents, "dora_docs": dora_docs, "dora_rts_docs": dora_rts_docs, "dora_news_docs": dora_news_docs}
+
+def generate(state):
     """
     Generate answer
 
@@ -198,17 +165,13 @@ async def generate(state):
     """
     print("---GENERATE---")
     question = state["question"]
-    fitting_documents = state["fitting_documents"]
-
-    dora_docs = [d for d in fitting_documents if d.metadata["source"].startswith("Dora")]
-    dora_rts_docs = [d for d in fitting_documents if d.metadata["source"].startswith("Commission")]
-    dora_news_docs = [d for d in fitting_documents if d.metadata["source"].startswith("https")]
+    documents = state["documents"]
 
     # RAG generation
-    generation = await answer_chain.ainvoke({"context": fitting_documents, "question": question})
-    return {"generation": generation, "dora_docs": dora_docs, "dora_rts_docs": dora_rts_docs, "dora_news_docs": dora_news_docs}
+    generation = answer_chain.invoke({"context": documents, "question": question})
+    return {"generation": generation}
 
-async def transform_query(state):
+def transform_query(state):
     """
     Transform the query to produce a better question.
 
@@ -223,12 +186,12 @@ async def transform_query(state):
     question = state["question"]
 
     # Re-write question
-    better_question = await question_rewriter.ainvoke({"question": question})
+    better_question = question_rewriter.invoke({"question": question})
     print(f"{better_question =}")
     return {"question": better_question}
 
 ### Edges ###
-async def suitable_question(state):
+def suitable_question(state):
     """
     Determines whether the question is suitable.
 
@@ -247,7 +210,7 @@ async def suitable_question(state):
     else:
         return "retrieve"
 
-async def decide_to_generate(state):
+def decide_to_generate(state):
     """
     Determines whether to generate an answer, or re-generate a question.
 
@@ -259,9 +222,9 @@ async def decide_to_generate(state):
     """
 
     print("---ASSESS GRADED DOCUMENTS---")
-    fitting_documents = state["fitting_documents"]
+    documents = state["documents"]
 
-    if not fitting_documents:
+    if not documents:
         # All documents have been filtered check_relevance
         # We will re-generate a new query
         print(
@@ -270,10 +233,10 @@ async def decide_to_generate(state):
         return "transform_query"
     else:
         # We have relevant documents, so generate answer
-        print(f"---DECISION: GENERATE WITH {len(fitting_documents)} DOCUMENTS---")
+        print(f"---DECISION: GENERATE WITH {len(documents)} DOCUMENTS---")
         return "generate"
 
-async def grade_generation_v_documents_and_question(state):
+def grade_generation_v_documents_and_question(state):
     """
     Determines whether the generation is grounded in the document and answers question.
 
@@ -286,11 +249,11 @@ async def grade_generation_v_documents_and_question(state):
 
     print("---CHECK HALLUCINATIONS---")
     question = state["question"]
-    fitting_documents = state["fitting_documents"]
+    documents = state["documents"]
     generation = state["generation"]
 
-    score = await hallucination_grader.ainvoke(
-        {"documents": fitting_documents, "generation": generation}
+    score = hallucination_grader.invoke(
+        {"documents": documents, "generation": generation}
     )
     grade = score.binary_score
 
@@ -299,7 +262,7 @@ async def grade_generation_v_documents_and_question(state):
         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
         # Check question-answering
         print("---GRADE GENERATION vs QUESTION---")
-        score = await answer_grader.ainvoke({"question": question, "generation": generation})
+        score = answer_grader.invoke({"question": question, "generation": generation})
         grade = score.binary_score
         if grade == "yes":
             print("---DECISION: GENERATION ADDRESSES QUESTION---")
@@ -308,7 +271,7 @@ async def grade_generation_v_documents_and_question(state):
             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
             return "not useful"
     else:
-        for document in fitting_documents:
+        for document in documents:
             print(document.page_content)
         print("---DECISION: THOSE DOCUMENTS ARE NOT GROUNDING THIS GENERATION---")
         print(f"{generation = }")
@@ -318,11 +281,10 @@ async def grade_generation_v_documents_and_question(state):
 def compile_graph():
     workflow = StateGraph(GraphState)
     # Define the nodes
-    workflow.add_node("dora_rewrite", dora_rewrite)  # retrieve
-    workflow.add_node("retrieve", retrieve)  # retrieve
-    workflow.add_node("grade_documents", grade_documents)  # grade documents
-    workflow.add_node("generate", generate)  # generate
-    workflow.add_node("transform_query", transform_query)  # transform_query
+    workflow.add_node("dora_rewrite", dora_rewrite)
+    workflow.add_node("retrieve", retrieve)
+    workflow.add_node("generate", generate)
+    workflow.add_node("transform_query", transform_query)
     # Define the edges
     workflow.add_edge(START, "dora_rewrite")
     workflow.add_conditional_edges(
@@ -333,9 +295,8 @@ def compile_graph():
             "end": END,
         },
     )
-    workflow.add_edge("retrieve", "grade_documents")
     workflow.add_conditional_edges(
-        "grade_documents",
+        "retrieve",
         decide_to_generate,
         {
             "transform_query": "transform_query",
@@ -357,9 +318,9 @@ def compile_graph():
     return app
 
 # Function to interact with Gradio
-async def generate_response(question: str, dora: bool, rts: bool, news: bool):
+def generate_response(question: str, dora: bool, rts: bool, news: bool):
     selected_sources = [dora, rts, news] if any([dora, rts, news]) else [True, False, False]
-    state = await app.ainvoke({"question": question, "selected_sources": selected_sources})
+    state = app.invoke({"question": question, "selected_sources": selected_sources})
     return (
         state["generation"],
         ('\n\n'.join([f"***{doc.metadata['source']} section {doc.metadata['section']}***: {doc.page_content}" for doc in state["dora_docs"]])) if "dora_docs" in state and state["dora_docs"] else 'No documents available.',
@@ -378,11 +339,21 @@ def clear_results():
 
 def random_prompt():
     return random.choice([
-        "Was ist der Unterschied zwischen TIBER-EU und DORA TLPT?",
-        "Ich möchte ein SIEM einführen. Bitte gib mir eine Checkliste, was ich beachten muss.",
-        "Was ist der Geltungsbereich der DORA? Bin ich als Finanzdienstleister im Leasinggeschäft betroffen?",
-        "Ich hatte einen Ransomwarevorfall mit erheblichen Auswirkungen auf den Geschäftsbetrieb. Muss ich etwas melden?",
-        "Was ist dieses DORA überhaupt?"
+        "How does DORA define critical ICT services and who must comply?",
+        "What are the key requirements for ICT risk management under DORA?",
+        "What are the reporting obligations under DORA for major incidents?",
+        "What third-party risk management requirements does DORA impose?",
+        "How does DORA's testing framework compare with the UK's CBEST framework?",
+        "Do ICT service providers fall under DORA's regulatory requirements?",
+        "How should I prepare for DORA's Threat-Led Penetration Testing (TLPT)?",
+        "What role do financial supervisors play in DORA compliance?",
+        "What penalties are applicable if an organization fails to comply with DORA?",
+        "How does DORA align with the NIS2 Directive in Europe?",
+        "Do insurance companies also fall under DORA's requirements?",
+        "What are the main differences between DORA and GDPR regarding incident reporting?",
+        "Are there specific resilience requirements for cloud service providers under DORA?",
+        "What are the main deadlines for compliance under DORA?",
+        "What steps should I take to ensure my third-party vendors are compliant with DORA?"
     ])
 
 def load_css():
@@ -476,12 +447,10 @@ if __name__ == "__main__":
     )
 
     fast_llm = ChatOpenAI(model="gpt-3.5-turbo")
-    smart_llm = ChatOpenAI(model="gpt-4-turbo", temperature=0.2, max_tokens=4096)
     tool_llm = ChatOpenAI(model="gpt-4o")
     rewrite_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=1, cache=False)
 
     dora_question_rewriter = IMPROVE_PROMPT | tool_llm | StrOutputParser()
-    retrieval_grader = RELEVANCE_PROMPT | fast_llm.with_structured_output(GradeDocuments)
     answer_chain = ANSWER_PROMPT | tool_llm | StrOutputParser()
     hallucination_grader = HALLUCINATION_PROMPT | fast_llm.with_structured_output(GradeHallucinations)
     answer_grader = RESOLVER_PROMPT | fast_llm.with_structured_output(GradeAnswer)
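
For intuition about what the new retriever stack does internally: CrossEncoderReranker scores each (query, document) pair with the cross-encoder and keeps the top_n best-scoring documents, which replaces the old per-document yes/no LLM grading loop. A rough standalone illustration, with made-up candidate texts:

from langchain_community.cross_encoders import HuggingFaceCrossEncoder

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
query = "What must be reported after a major ICT-related incident?"
candidates = [
    "Financial entities shall report major ICT-related incidents to the competent authority ...",  # invented snippet
    "Our cafeteria menu for next week ...",  # clearly irrelevant filler
]
scores = model.score([(query, text) for text in candidates])  # one relevance float per pair
ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
top_n = [text for text, _ in ranked[:7]]  # the reranker's keep-the-best-7 step

Because the cutoff is deterministic, retrieve can hand its documents straight to decide_to_generate, and the grade_documents node and its extra graph edge disappear.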
prompts.py CHANGED
@@ -16,30 +16,24 @@ IMPROVE_PROMPT = ChatPromptTemplate.from_messages(
     ]
 )
 
-RELEVANCE_PROMPT = ChatPromptTemplate.from_messages(
-    [
-        ("system", """You are a grader assessing relevance of a retrieved document to a user question. \n
-            If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
-            It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
-            Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
-        ),
-        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
-    ]
-)
-
 ANSWER_PROMPT = ChatPromptTemplate.from_messages(
     [
         (
             "system",
-            "You are a highly experienced IT auditor, specializing in information security and regulatory compliance. Your task is to assist a colleague who has approached you with a question."
-            " You have access to relevant context, provided here: {context}."
-            " Please respond with a clear, concise, and precise answer, strictly based on the provided context. Ensure your response is accurate and always cite sources from the context."
-            " Do not introduce any new information or alter the context in any way."
+            "You are a highly experienced IT auditor, specializing in information security and regulatory compliance. "
+            "Your task is to assist a colleague who has approached you with a question. "
+            "You have access to relevant context, provided here: {context}. "
+            "Make your response as informative as possible and make sure every sentence is supported by the provided context. "
+            "Each piece of information must be backed up by a citation from at least one of the information sources in the context, formatted as a footnote, reproducing the source after your response. "
+            "Your answer should be structured and suitable for regulatory documentation or audit reporting. "
+            "If you do not have a citation from the provided source material in the message, explicitly state: 'No citations found.' Never generate a citation if no source material is provided. "
+            "Ensure all relevant details from the context are included in your response."
         ),
         ("user", "{question}"),
    ]
 )
 
+
 HALLUCINATION_PROMPT = ChatPromptTemplate.from_messages(
 
     [
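
The reworked ANSWER_PROMPT, which now demands footnote-style citations, flows into answer_chain in app.py unchanged. A sketch of how it gets invoked, using an invented single-document context:

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from prompts import ANSWER_PROMPT

tool_llm = ChatOpenAI(model="gpt-4o")
answer_chain = ANSWER_PROMPT | tool_llm | StrOutputParser()

# In the app, the context is the reranked documents from the retrieve node;
# this single Document is a made-up stand-in.
documents = [Document(page_content="This Regulation lays down uniform requirements ...",
                      metadata={"source": "Dora", "section": "Article 1"})]
answer = answer_chain.invoke({"context": documents, "question": "What does DORA regulate?"})
print(answer)  # should end with footnote-style citations, per the new system prompt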
requirements.txt CHANGED
@@ -6,4 +6,17 @@ langgraph==0.2.41
 pydantic==2.9.2
 typing_extensions==4.12.2
 faiss-cpu==1.8.0.post1
-asyncio==3.4.3
+asyncio==3.4.3
+joblib==1.4.2
+mpmath==1.3.0
+networkx==3.4.2
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+sentence-transformers==3.3.1
+setuptools==75.5.0
+sympy==1.13.1
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+torch==2.5.1
+transformers==4.46.2