import os
import logging
import sys

import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever, PromptModel, PromptNode, PromptTemplate, Shaper

logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s %(asctime)s %(name)s:%(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)


def get_plain_pipeline():
    # The API key is read from the environment; app_init() sets it from Streamlit secrets.
    prompt_open_ai = PromptModel(
        model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY")
    )

    # A single PromptNode that passes the raw question straight to the OpenAI model:
    plain_llm_template = PromptTemplate(
        name="plain_llm", prompt_text="Answer the following question: $query"
    )
    node_openai = PromptNode(prompt_open_ai, default_prompt_template=plain_llm_template, max_length=300)

    pipeline = Pipeline()
    pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
    return pipeline
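

# Hypothetical quick check of the plain pipeline (not part of the app's
# Streamlit flow), assuming OPENAI_API_KEY is already set in the environment:
#
#     result = get_plain_pipeline().run(query="What is Haystack?")
#     print(result["answers"][0].answer)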


def get_ret_aug_pipeline():
    # Load a prebuilt FAISS index from disk (see the indexing sketch below this function).
    ds = FAISSDocumentStore(
        faiss_index_path="my_faiss_index.faiss",
        faiss_config_path="my_faiss_index.json",
    )

    retriever = EmbeddingRetriever(
        document_store=ds,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    # Join the retrieved documents into a single string for the prompt:
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])

    default_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
                    "$query; Answer:",
    )
    # Initialize the PromptNode; the API key is read from the environment.
    node = PromptNode(
        "text-davinci-003",
        default_prompt_template=default_template,
        api_key=os.environ.get("OPENAI_API_KEY"),
        max_length=500,
    )

    # Chain retriever -> shaper -> prompt node:
    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
    return pipe
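

# get_ret_aug_pipeline() expects "my_faiss_index.faiss" and "my_faiss_index.json"
# to already exist on disk. Below is a minimal sketch of how such an index could
# be built with Haystack's FAISSDocumentStore; `build_faiss_index` and the
# `texts` argument are illustrative assumptions, not part of the original app.
def build_faiss_index(texts):
    from haystack.schema import Document

    # multi-qa-mpnet-base-dot-v1 produces 768-dim embeddings and is trained
    # for dot-product similarity.
    store = FAISSDocumentStore(embedding_dim=768, similarity="dot_product")
    store.write_documents([Document(content=t) for t in texts])

    retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )
    store.update_embeddings(retriever)
    store.save(index_path="my_faiss_index.faiss", config_path="my_faiss_index.json")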


@st.cache_resource
def app_init():
    # Build both pipelines once and cache them across Streamlit reruns
    # (st.cache_resource requires streamlit >= 1.18).
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    p1 = get_plain_pipeline()
    p2 = get_ret_aug_pipeline()
    return p1, p2


def main():
    p1, p2 = app_init()
    st.title("Haystack Demo")
    query = st.text_input("Query ...")

    query_type = st.radio("Type",
                          ("Retrieval Augmented", "Retrieval Augmented with Sources",
                           "Retrieval Augmented with Web Search"))

    # Don't call the API until the user has actually typed a query.
    if not query:
        st.stop()

    col_1, col_2 = st.columns(2)

    with col_1:
        st.text("PLAIN")
        answers = p1.run(query=query)["answers"]
        for ans in answers:
            st.text(ans.answer)

    with col_2:
        st.write(query_type.upper())
        answers = p2.run(query=query)["answers"]
        for ans in answers:
            st.text(ans.answer)


if __name__ == "__main__":
    main()
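

# To run locally (a sketch; the filename "app.py" and the faiss extra are
# assumptions about this project's setup):
#
#     pip install "farm-haystack[faiss]" streamlit
#     echo 'OPENAI_API_KEY = "sk-..."' > .streamlit/secrets.toml
#     streamlit run app.py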