amrohendawi committed
Commit: 35c8ded
Parent: 03b2221

swapped query_pipeline with a chat-enabled inference_pipeline

Files changed:
- app.py (+4 -3)
- document_qa_engine.py (+50 -30)
app.py CHANGED

@@ -1,4 +1,3 @@
-import os
 from dotenv import load_dotenv
 import pandas as pd
 import streamlit as st
@@ -135,10 +134,12 @@ def display_chat_messages(chat_box, chat_input):
         st.markdown(message["content"], unsafe_allow_html=True)
 
     st.chat_message("user").markdown(chat_input)
-    st.session_state.messages.append({"role": "user", "content": chat_input})
     with st.chat_message("assistant"):
-        …
+        # process user input and generate response
+        response = st.session_state['document_qa_model'].inference(chat_input, st.session_state.messages)
+
         st.markdown(response)
+    st.session_state.messages.append({"role": "user", "content": chat_input})
     st.session_state.messages.append({"role": "assistant", "content": response})
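For orientation, here is a minimal sketch of the chat turn this hunk implements. It assumes st.session_state['document_qa_model'] already holds the DocumentQAEngine instance, as elsewhere in app.py; the input label is illustrative:

# One chat turn using the new inference() API; session-state keys follow app.py.
import streamlit as st

if "messages" not in st.session_state:
    st.session_state.messages = []

if chat_input := st.chat_input("Ask about the uploaded CV"):
    st.chat_message("user").markdown(chat_input)
    with st.chat_message("assistant"):
        # the history goes in without the current question,
        # which is passed separately as the query argument
        response = st.session_state["document_qa_model"].inference(
            chat_input, st.session_state.messages)
        st.markdown(response)
    st.session_state.messages.append({"role": "user", "content": chat_input})
    st.session_state.messages.append({"role": "assistant", "content": response})

Note the ordering: the user turn is appended to the history only after inference() returns, so the current question reaches the engine exactly once, via query, rather than duplicated in input_messages.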
document_qa_engine.py CHANGED

@@ -1,4 +1,6 @@
 from typing import List
+
+from haystack.dataclasses import ChatMessage
 from pypdf import PdfReader
 from haystack.utils import Secret
 from haystack import Pipeline, Document, component
@@ -8,9 +10,8 @@ from haystack.components.writers import DocumentWriter
 from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
 from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
-from haystack.components.builders import PromptBuilder
+from haystack.components.builders import DynamicChatPromptBuilder
 from haystack.components.generators.chat import OpenAIChatGenerator, HuggingFaceTGIChatGenerator
-from haystack.components.generators import OpenAIGenerator, HuggingFaceTGIGenerator
 from haystack.document_stores.types import DuplicatePolicy
 
 SENTENCE_RETREIVER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
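The new import brings in Haystack's chat-message dataclass. As a hedged micro-example (strings illustrative), these are the role-tagged constructors the new inference() method further down relies on:

from haystack.dataclasses import ChatMessage

# role-tagged constructors; each returns a ChatMessage carrying that role
system_msg = ChatMessage.from_system("You answer questions about the uploaded CV.")
user_msg = ChatMessage.from_user("Where did the candidate study?")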
@@ -70,34 +71,32 @@ def create_ingestion_pipeline(document_store):
     return pipeline
 
 
-def create_query_pipeline(document_store, model_name, api_key):
-    prompt_builder = PromptBuilder(template=template)
+def create_inference_pipeline(document_store, model_name, api_key):
     if model_name == "local LLM":
-        generator = …
+        generator = OpenAIChatGenerator(model=model_name,
+                                        api_base_url="http://localhost:1234/v1",
+                                        generation_kwargs={"max_tokens": MAX_TOKENS}
+                                        )
     elif "gpt" in model_name:
-        generator = …
+        generator = OpenAIChatGenerator(api_key=Secret.from_token(api_key), model=model_name,
+                                        generation_kwargs={"max_tokens": MAX_TOKENS, "stream": False}
+                                        )
     else:
-        generator = …
+        generator = HuggingFaceTGIChatGenerator(token=Secret.from_token(api_key), model=model_name,
+                                                generation_kwargs={"max_new_tokens": MAX_TOKENS}
+                                                )
-    …
-    query_pipeline.connect("prompt_builder", "generator")
+    pipeline = Pipeline()
+    pipeline.add_component("text_embedder",
+                           SentenceTransformersTextEmbedder(model=SENTENCE_RETREIVER_MODEL))
+    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
+    pipeline.add_component("prompt_builder",
+                           DynamicChatPromptBuilder(runtime_variables=["query", "documents"]))
+    pipeline.add_component("llm", generator)
+    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
+    pipeline.connect("retriever.documents", "prompt_builder.documents")
+    pipeline.connect("prompt_builder.prompt", "llm.messages")
 
-    return query_pipeline
+    return pipeline
 
 
 class DocumentQAEngine:
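For readers new to Haystack 2.x pipelines, a hedged sketch of exercising this graph directly, outside DocumentQAEngine; document_store, the model name, api_key, and the question are placeholders, and the prompt_source/query inputs mirror the run() call in inference() below:

from haystack.dataclasses import ChatMessage

pipe = create_inference_pipeline(document_store, "gpt-3.5-turbo", api_key)
question = "What is the candidate's last employer?"
# DynamicChatPromptBuilder renders the trailing user message as a Jinja
# template against its runtime_variables ("query", "documents")
template = ChatMessage.from_user(
    "{% for doc in documents %}{{ doc.content }}{% endfor %}\nQuestion: {{query}}")
res = pipe.run(data={"text_embedder": {"text": question},
                     "prompt_builder": {"prompt_source": [template],
                                        "query": question}})
print(res["llm"]["replies"][0].content)

This is also why the wiring above routes retriever.documents into the prompt builder rather than the LLM: only the rendered chat messages reach llm.messages.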
@@ -109,12 +108,33 @@ class DocumentQAEngine:
         self.model_name = model_name
         document_store = InMemoryDocumentStore()
         self.chunks = []
-        self.query_pipeline = create_query_pipeline(document_store, model_name, api_key)
+        self.inference_pipeline = create_inference_pipeline(document_store, model_name, api_key)
         self.pdf_ingestion_pipeline = create_ingestion_pipeline(document_store)
 
     def ingest_pdf(self, uploaded_file):
         self.pdf_ingestion_pipeline.run({"converter": {"uploaded_file": uploaded_file}})
 
-    def …
-        …
-        …
+    def inference(self, query, input_messages: List[dict]):
+        system_message = ChatMessage.from_system(
+            "You are a professional HR recruiter that answers questions based on the content of the uploaded CV, in 1 or 2 sentences.")
+        messages = [system_message]
+        # replay the chat history with the matching ChatMessage roles
+        for message in input_messages:
+            if message["role"] == "user":
+                messages.append(ChatMessage.from_user(message["content"]))
+            else:
+                messages.append(ChatMessage.from_assistant(message["content"]))
+        # trailing user message doubles as the Jinja template the
+        # prompt builder fills with the retrieved documents and the query
+        messages.append(ChatMessage.from_user("""
+        Relevant information from the uploaded CV:
+        {% for doc in documents %}
+            {{ doc.content }}
+        {% endfor %}
+
+        Question: {{query}}
+        Answer:
+        """))
+        res = self.inference_pipeline.run(data={"text_embedder": {"text": query},
+                                                "prompt_builder": {"prompt_source": messages,
+                                                                   "query": query}})
+        return res["llm"]["replies"][0].content
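Finally, a hedged end-to-end sketch of the engine after this commit. The constructor signature is inferred from the fields set in __init__, and uploaded_file stands in for the file object Streamlit's uploader yields:

# assumed signature: DocumentQAEngine(model_name=..., api_key=...)
engine = DocumentQAEngine(model_name="gpt-3.5-turbo", api_key="sk-...")
engine.ingest_pdf(uploaded_file)  # index the CV into the in-memory store

history = []  # [{"role": ..., "content": ...}] dicts, as app.py keeps them
answer = engine.inference("Summarize the candidate's experience.", history)
history.append({"role": "user", "content": "Summarize the candidate's experience."})
history.append({"role": "assistant", "content": answer})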