notSoNLPnerd's picture
latest 2 pipelines
617bb16
raw
history blame
2.96 kB
import glob
import logging
import os
import sys

import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever, PromptModel, PromptNode, PromptTemplate, Shaper
from haystack.schema import Document
# Send every log record at DEBUG level or above to stdout. ``force=True``
# makes basicConfig replace any handlers already attached to the root logger.
logging.basicConfig(
    stream=sys.stdout,
    format="%(levelname)s %(asctime)s %(name)s:%(message)s",
    level=logging.DEBUG,
    force=True,
)

# Module-level placeholders; nothing in the visible code reassigns them.
p_1 = p_2 = None
def get_plain_pipeline(api_key=None):
    """Build a pipeline that sends the query straight to the OpenAI model.

    The pipeline consists of a single ``PromptNode`` named ``"prompt_node"``
    that is wired to the ``Query`` input; no retrieval step is involved.

    Args:
        api_key: OpenAI API key. When omitted, the key is read from the
            ``OPENAI_API_KEY`` environment variable.

    Returns:
        The assembled ``Pipeline``.

    Raises:
        KeyError: If no key is passed and ``OPENAI_API_KEY`` is not set.
    """
    if api_key is None:
        # Fall back to the environment so callers need not pass the key around.
        api_key = os.environ["OPENAI_API_KEY"]

    prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=api_key)
    plain_llm_template = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )
    node_openai = PromptNode(
        prompt_open_ai,
        default_prompt_template=plain_llm_template,
        max_length=300,
    )

    pipeline = Pipeline()
    pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
    return pipeline
def get_ret_aug_pipeline(api_key=None):
    """Build the retrieval-augmented pipeline: retriever -> shaper -> prompt node.

    The document store is loaded from ``my_faiss_index.faiss`` and
    ``my_faiss_index.json``; both paths are relative, so they are resolved
    against the current working directory.

    Args:
        api_key: OpenAI API key. When omitted, the key is read from the
            ``OPENAI_API_KEY`` environment variable.

    Returns:
        The assembled ``Pipeline`` with nodes ``"retriever"``, ``"shaper"``
        and ``"prompt_node"`` chained in that order.

    Raises:
        KeyError: If no key is passed and ``OPENAI_API_KEY`` is not set.
    """
    if api_key is None:
        # Fall back to the environment so callers need not pass the key around.
        api_key = os.environ["OPENAI_API_KEY"]

    ds = FAISSDocumentStore(
        faiss_index_path="my_faiss_index.faiss",
        faiss_config_path="my_faiss_index.json",
    )
    retriever = EmbeddingRetriever(
        document_store=ds,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    # The Shaper applies its ``join_documents`` function to the retrieved
    # documents and hands the result on under the same ``documents`` key,
    # which the template below consumes as ``$documents``.
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
    default_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
        "$query; Answer:",
    )
    node = PromptNode(
        "text-davinci-003",
        default_prompt_template=default_template,
        api_key=api_key,
        max_length=500,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
    return pipe
def app_init():
    """Export the OpenAI key from Streamlit secrets and build both pipelines.

    Returns:
        A ``(plain_pipeline, retrieval_augmented_pipeline)`` tuple.
    """
    # Publish the secret through the process environment before any pipeline
    # is constructed.
    openai_key = st.secrets["OPENAI_API_KEY"]
    os.environ["OPENAI_API_KEY"] = openai_key

    # Tuple elements are evaluated left to right: plain pipeline first.
    return get_plain_pipeline(), get_ret_aug_pipeline()
def main():
    """Render the demo page and show both pipelines' answers side by side.

    The left column shows the plain LLM pipeline, the right column the
    retrieval-augmented one. The selected radio option is only used as the
    right column's heading; it does not change which pipeline runs.
    """
    p1, p2 = app_init()
    st.title("Haystack Demo")
    query = st.text_input("Query ...")
    query_type = st.radio(
        "Type",
        (
            "Retrieval Augmented",
            "Retrieval Augmented with Sources",
            "Retrieval Augmented with Web Search",
        ),
    )

    # The text box is empty until the user types something; do not spend
    # LLM calls on a blank query.
    if not (query and query.strip()):
        return

    col_1, col_2 = st.columns(2)
    with col_1:
        st.text("PLAIN")
        _show_answers(p1.run(query))
    with col_2:
        st.write(query_type.upper())
        _show_answers(p2.run(query))


def _show_answers(result):
    """Write each answer contained in a pipeline result as a text element."""
    answers = result.get("answers")
    if answers is None:
        # NOTE(review): Haystack 1.x PromptNode appears to return its
        # generations as plain strings under "results" rather than "answers";
        # confirm against the installed Haystack version.
        answers = result.get("results", [])
    for ans in answers:
        # Answer objects expose the text as ``.answer``; plain strings are
        # shown as they are.
        st.text(getattr(ans, "answer", ans))
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()