File size: 4,974 Bytes
4c6d98a
 
 
 
 
 
 
 
85affd8
21901c4
c29148e
4c6d98a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4271625
4c6d98a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8870bd5
4c6d98a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
from dotenv import load_dotenv
load_dotenv(".env")

import warnings

# os.getenv() reads os.environ, so the former `os.environ[X] = os.getenv(X)`
# round-trip changed nothing when X was set and raised an opaque
# "TypeError: str expected, not NoneType" when it was not. Check the variables
# explicitly and warn by name instead, so a missing setting is easy to
# diagnose and start-up can continue to the code that actually needs it.
for _env_name in ("USER_AGENT", "GROQ_API_KEY"):
    if os.getenv(_env_name) is None:
        warnings.warn(
            f"Environment variable {_env_name} is not set; "
            "define it in .env or in the process environment.",
            RuntimeWarning,
        )
del _env_name  # keep the loop variable out of the module namespace

# Explicitly opt in to tokenizer parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = 'true'

import nltk
# Fetch the NLTK 'punkt_tab' resource at start-up.
# NOTE(review): presumably required by the BM25Encoder tokenizer used further
# down — confirm before removing.
nltk.download('punkt_tab')

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory

from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import PineconeHybridSearchRetriever

from langchain_groq import ChatGroq

import gradio as gr
import spaces
import torch


# Connect to the Pinecone index, retrying once if the first attempt fails.
# Only `Exception` is caught so that Ctrl-C / SystemExit are never swallowed,
# and a failure on the final attempt propagates to the caller.
index_name = "traveler-demo-website-vectorstore"
for _attempt in range(2):
    try:
        pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        # connect to index
        pinecone_index = pc.Index(index_name)
        break
    except Exception:
        if _attempt == 1:
            raise
del _attempt  # keep the loop variable out of the module namespace

# Sparse encoder, loaded from the JSON file that sits next to this script.
bm25 = BM25Encoder().load("./bm25_traveler_website.json")

# Dense embedding model, configured for the CUDA device and permitted to run
# code shipped with the model repository (trust_remote_code).
embed_model = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    model_kwargs=dict(trust_remote_code=True, device="cuda"),
)

# Hybrid retriever over the Pinecone index, combining the dense embeddings
# with the sparse encoder and returning up to 20 hits per query.
# NOTE(review): alpha=0.5 presumably weights dense and sparse scores equally —
# confirm against the retriever documentation.
retriever = PineconeHybridSearchRetriever(
    index=pinecone_index,
    embeddings=embed_model,
    sparse_encoder=bm25,
    alpha=0.5,
    top_k=20,
)

# Chat model served through Groq; low temperature, capped answer length.
llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0.1,
    max_tokens=1024,
    max_retries=2,
)

### Contextualize question ###
# System prompt that asks the model to rewrite the latest user question into a
# standalone question (without answering it). The trailing backslashes join
# the source lines, so the instruction is one paragraph followed by a newline.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
"""
# Message order: system instruction, then the prior turns injected under
# "chat_history", then the new user question from "{input}".
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ]
)

# Retriever built from the LLM, the hybrid retriever and the prompt above.
# NOTE(review): presumably the LLM reformulates the question before the
# retriever is queried — confirm against the LangChain documentation.
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)


# System prompt for the answering step. "{context}" at the end is the slot the
# documents chain fills with the retrieved documents.
# NOTE(review): every line but one ends with a backslash, so most of the
# prompt (including the list indentation) is joined into a single long line;
# the line ending "mentioned in the answer." has no backslash and keeps its
# newline. That same line reads ungrammatically ("where there is a mentioned"),
# and "Do not give extra long answers" sits next to "comprehensive" — worth
# tidying in a deliberate prompt revision.
qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
Provide links to sources provided in the answer. \
If you don't know the answer, just say that you don't know. \
Do not give extra long answers. \
When responding to queries, your responses should be comprehensive and well-organized. For each response: \
    1. Provide Clear Answers \
    2. Include Detailed References: \
        - Include links to sources and any links or sites where there is a mentioned in the answer.
        - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
        - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
        - Reference Sites: Mention specific websites or platforms that offer additional information. \
    3. Formatting for Readability: \
        - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
        - Emphasize Important Information: Use bold or italics to highlight key details. \
    4. Organize Content Logically \
Do not include anything about context in the answer. \
{context}
"""
# Message order: system instruction (with retrieved context), prior turns
# under "chat_history", then the user question from "{input}".
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ]
)
# Chain that feeds the retrieved documents plus the prompt to the LLM.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Full RAG pipeline: history-aware retrieval followed by the answering chain.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

### Per-session chat history, held in process memory ###
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the history stored for *session_id*, creating an empty one on first use."""
    try:
        return store[session_id]
    except KeyError:
        fresh_history = ChatMessageHistory()
        store[session_id] = fresh_history
        return fresh_history


# Wrap the RAG chain so chat history is looked up through get_session_history.
# The key names tie the pieces together:
#   - "input" is the key handle_message passes and the "{input}" prompt slot,
#   - "chat_history" is the MessagesPlaceholder name used in both prompts,
#   - "answer" is the output key handle_message picks from the result.
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

@spaces.GPU(duration=5)
def handle_message(question, history=None):
    """Stream the chain's answer to *question* as a growing string.

    Args:
        question: The user's message; passed to the chain under the "input" key.
        history: Accepted for compatibility with the chat UI callback
            signature but not used here; the chain looks up its own history
            through ``get_session_history``. Defaults to ``None`` rather than
            a shared mutable ``{}``.

    Yields:
        The answer accumulated so far, once per streamed chunk, so the caller
        can re-render the partial response each time.
    """
    # NOTE(review): the session id is hard-coded, so every caller reads and
    # writes the same stored chat history. Derive it per user/session if
    # conversations must stay separate.
    answer_stream = conversational_rag_chain.pick("answer").stream(
        {"input": question},
        config={"configurable": {"session_id": "abc123"}},
    )

    response = ''
    for chunk in answer_stream:
        response += chunk
        yield response

# Build and launch the Gradio chat UI only when this file is run as a script,
# not when it is imported.
if __name__ == '__main__': 
    # handle_message is registered as the chat callback.
    demo = gr.ChatInterface(fn=handle_message)
    demo.launch()