Commit 35c8ded by amrohendawi
1 parent: 03b2221

swapped query_pipeline with a chat-enabled inference_pipeline
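The swap changes the generator contract: the old `OpenAIGenerator` takes a single rendered prompt string and returns plain-string replies, while `OpenAIChatGenerator` consumes a list of `ChatMessage` objects and returns `ChatMessage` replies, which is what lets the pipeline carry multi-turn history. A minimal sketch of the difference, assuming Haystack 2.x (the model name and token below are placeholders, not values from this repo):

from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.chat import OpenAIChatGenerator

api_key = Secret.from_token("sk-...")  # placeholder credential

# before: single-shot, string in / string out
generator = OpenAIGenerator(api_key=api_key, model="gpt-3.5-turbo")
out = generator.run(prompt="Summarize the CV in one sentence.")
print(out["replies"][0])  # plain str

# after: chat-enabled, ChatMessage list in / ChatMessage list out
chat_generator = OpenAIChatGenerator(api_key=api_key, model="gpt-3.5-turbo")
out = chat_generator.run(messages=[
    ChatMessage.from_system("You are a professional HR recruiter."),
    ChatMessage.from_user("Summarize the CV in one sentence."),
])
print(out["replies"][0].content)  # ChatMessage, so .content is needed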

Files changed (2):
  1. app.py (+4, -3)
  2. document_qa_engine.py (+50, -30)
app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 from dotenv import load_dotenv
 import pandas as pd
 import streamlit as st
@@ -135,10 +134,12 @@ def display_chat_messages(chat_box, chat_input):
         st.markdown(message["content"], unsafe_allow_html=True)
 
     st.chat_message("user").markdown(chat_input)
-    st.session_state.messages.append({"role": "user", "content": chat_input})
     with st.chat_message("assistant"):
-        response = st.session_state['document_qa_model'].process_message(chat_input)
+        # process user input and generate response
+        response = st.session_state['document_qa_model'].inference(chat_input, st.session_state.messages)
+
        st.markdown(response)
+    st.session_state.messages.append({"role": "user", "content": chat_input})
     st.session_state.messages.append({"role": "assistant", "content": response})
 
 
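Note the reordering in app.py: the user turn is now appended to `st.session_state.messages` only after `inference` returns, so the history handed to the model contains previous turns only, and the current question arrives through the separate `chat_input` argument. A minimal sketch of the surrounding chat loop, assuming the usual Streamlit session-state initialization (the init guard and prompt text are assumptions; `document_qa_model` and the message dict shape are taken from the diff):

import streamlit as st

# assumed initialization guard; not shown in this diff
if "messages" not in st.session_state:
    st.session_state.messages = []

chat_input = st.chat_input("Ask a question about the uploaded CV")
if chat_input:
    # replay earlier turns so the page re-renders the full conversation
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"], unsafe_allow_html=True)

    st.chat_message("user").markdown(chat_input)
    with st.chat_message("assistant"):
        # history excludes the current question; it is passed separately
        response = st.session_state["document_qa_model"].inference(chat_input, st.session_state.messages)
        st.markdown(response)
    st.session_state.messages.append({"role": "user", "content": chat_input})
    st.session_state.messages.append({"role": "assistant", "content": response})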
document_qa_engine.py CHANGED
@@ -1,4 +1,6 @@
 from typing import List
+
+from haystack.dataclasses import ChatMessage
 from pypdf import PdfReader
 from haystack.utils import Secret
 from haystack import Pipeline, Document, component
@@ -8,9 +10,8 @@ from haystack.components.writers import DocumentWriter
 from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
 from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
-from haystack.components.builders import PromptBuilder
+from haystack.components.builders import DynamicChatPromptBuilder
 from haystack.components.generators.chat import OpenAIChatGenerator, HuggingFaceTGIChatGenerator
-from haystack.components.generators import OpenAIGenerator, HuggingFaceTGIGenerator
 from haystack.document_stores.types import DuplicatePolicy
 
 SENTENCE_RETREIVER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
@@ -70,34 +71,32 @@ def create_ingestion_pipeline(document_store):
     return pipeline
 
 
-def create_query_pipeline(document_store, model_name, api_key):
-    prompt_builder = PromptBuilder(template=template)
+def create_inference_pipeline(document_store, model_name, api_key):
     if model_name == "local LLM":
-        generator = OpenAIGenerator(model=model_name,
-                                    api_base_url="http://localhost:1234/v1",
-                                    generation_kwargs={"max_tokens": MAX_TOKENS}
-                                    )
+        generator = OpenAIChatGenerator(model=model_name,
+                                        api_base_url="http://localhost:1234/v1",
+                                        generation_kwargs={"max_tokens": MAX_TOKENS}
+                                        )
     elif "gpt" in model_name:
-        generator = OpenAIGenerator(api_key=Secret.from_token(api_key), model=model_name,
-                                    generation_kwargs={"max_tokens": MAX_TOKENS}
-                                    )
+        generator = OpenAIChatGenerator(api_key=Secret.from_token(api_key), model=model_name,
+                                        generation_kwargs={"max_tokens": MAX_TOKENS, "stream": False}
+                                        )
     else:
-        generator = HuggingFaceTGIGenerator(token=Secret.from_token(api_key), model=model_name,
-                                            generation_kwargs={"max_new_tokens": MAX_TOKENS}
-                                            )
-
-    query_pipeline = Pipeline()
-    query_pipeline.add_component("text_embedder",
-                                 SentenceTransformersTextEmbedder(model=SENTENCE_RETREIVER_MODEL))
-    query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
-    query_pipeline.add_component("prompt_builder", prompt_builder)
-    query_pipeline.add_component("generator", generator)
-
-    query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
-    query_pipeline.connect("retriever.documents", "prompt_builder.documents")
-    query_pipeline.connect("prompt_builder", "generator")
+        generator = HuggingFaceTGIChatGenerator(token=Secret.from_token(api_key), model=model_name,
+                                                generation_kwargs={"max_new_tokens": MAX_TOKENS}
+                                                )
+    pipeline = Pipeline()
+    pipeline.add_component("text_embedder",
+                           SentenceTransformersTextEmbedder(model=SENTENCE_RETREIVER_MODEL))
+    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
+    pipeline.add_component("prompt_builder",
+                           DynamicChatPromptBuilder(runtime_variables=["query", "documents"]))
+    pipeline.add_component("llm", generator)
+    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
+    pipeline.connect("retriever.documents", "prompt_builder.documents")
+    pipeline.connect("prompt_builder.prompt", "llm.messages")
 
-    return query_pipeline
+    return pipeline
 
 
 class DocumentQAEngine:
@@ -109,12 +108,33 @@ class DocumentQAEngine:
         self.model_name = model_name
         document_store = InMemoryDocumentStore()
         self.chunks = []
-        self.query_pipeline = create_query_pipeline(document_store, model_name, api_key)
+        self.inference_pipeline = create_inference_pipeline(document_store, model_name, api_key)
         self.pdf_ingestion_pipeline = create_ingestion_pipeline(document_store)
 
     def ingest_pdf(self, uploaded_file):
         self.pdf_ingestion_pipeline.run({"converter": {"uploaded_file": uploaded_file}})
 
-    def process_message(self, query):
-        response = self.query_pipeline.run({"text_embedder": {"text": query}, "prompt_builder": {"question": query}})
-        return response["generator"]["replies"][0]
+    def inference(self, query, input_messages: List[dict]):
+        system_message = ChatMessage.from_system(
+            "You are a professional HR recruiter that answers questions based on the content of the uploaded CV, in 1 or 2 sentences.")
+        messages = [system_message]
+        for message in input_messages:
+            # rebuild the chat history with the correct role for each turn
+            if message["role"] == "user":
+                messages.append(ChatMessage.from_user(message["content"]))
+            else:
+                messages.append(ChatMessage.from_assistant(message["content"]))
+        messages.append(ChatMessage.from_user("""
+        Relevant information from the uploaded CV:
+        {% for doc in documents %}
+            {{ doc.content }}
+        {% endfor %}
+
+        Question: {{query}}
+        Answer:
+        """))
+        res = self.inference_pipeline.run(data={"text_embedder": {"text": query},
+                                                "prompt_builder": {"prompt_source": messages,
+                                                                   "query": query}})
+        return res["llm"]["replies"][0].content
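With `DynamicChatPromptBuilder(runtime_variables=["query", "documents"])`, the Jinja placeholders in the final user message are rendered at run time: `documents` arrives over the `retriever.documents` -> `prompt_builder.documents` connection and `query` through the `prompt_builder` input dict, so the retrieved CV chunks are injected into the prompt on every turn. A hedged end-to-end sketch of the new interface, assuming the constructor takes `model_name` and `api_key` as the diff's attribute assignments suggest (the file name, model, and key are placeholders; `ingest_pdf` expects an uploaded PDF file object per the ingestion pipeline):

from document_qa_engine import DocumentQAEngine

# placeholder model and credential
engine = DocumentQAEngine(model_name="gpt-3.5-turbo", api_key="sk-...")

with open("cv.pdf", "rb") as uploaded_file:  # placeholder file
    engine.ingest_pdf(uploaded_file)

# history uses the same dict shape app.py keeps in st.session_state.messages
history = []
question = "How many years of Python experience does the candidate have?"
answer = engine.inference(question, history)
history.append({"role": "user", "content": question})
history.append({"role": "assistant", "content": answer})
print(answer)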