File size: 6,331 Bytes
fe1fc2e
467c73a
3c4de7d
bd26a11
08a925d
bd26a11
 
 
8356c3c
bd26a11
 
 
467c73a
fe1fc2e
8356c3c
dfb65c1
a76c41e
 
08a925d
fcd5615
10edf7a
7c9ee97
 
 
 
 
92bd8e5
7c9ee97
e1ca13f
0772e8b
7c9ee97
 
40e10a3
305c673
c81e469
8e41c14
6684bbc
10edf7a
 
 
 
 
 
152b9b0
ca65ca1
b0188c2
5d24867
 
 
 
 
 
 
33ab765
5d24867
91fbf6f
 
ad5fff5
 
91fbf6f
5d24867
152b9b0
0420d34
5d24867
152b9b0
 
33ab765
de0cc6e
de7666e
152b9b0
fe815e5
55deafd
de0cc6e
ea1938d
d1e59aa
 
 
 
 
 
152b9b0
d1e59aa
f9b8984
cd7cb8b
f9b8984
886eded
55e669b
229db9f
d1e59aa
5aa5326
886eded
6c165fe
5a5165b
2bc8ef5
cd7cb8b
 
 
 
 
36b9a37
2ecb0e9
ad5fff5
2ecb0e9
ad5fff5
2ecb0e9
 
 
 
2c751e4
 
2ecb0e9
ad5fff5
660cd45
 
 
e42b949
 
6a474c7
443a65d
e42b949
443a65d
 
799176a
 
ce39389
799176a
443a65d
261a9b2
03d617b
261a9b2
9f3b8b8
 
e42b949
687dbd6
 
22a7f5f
8e41c14
 
 
 
 
 
2bc8ef5
 
 
e42b949
3d410fb
687dbd6
 
9e61368
8e41c14
9e61368
 
687dbd6
9f3b8b8
fd0bd52
2bc8ef5
05d1ad8
dedb405
fe1fc2e
3c4de7d
c547536
fcd5615
0c95628
0007508
ab2f443
6c165fe
98c0504
8e29761
6c165fe
 
 
55926d9
8356c3c
 
fe1fc2e
fcd5615
08a925d
3a79f80
8eab994
b05e06e
dd38b99
a34af5b
b279902
ab1db22
b279902
b2ec3a4
fe1fc2e
fcd5615
fe1fc2e
81c159a
fe8eb6d
5bd324a
815187b
10edf7a
815187b
e1ca13f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import streamlit as st
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import llamacpp, LlamaCpp
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain.chains import create_history_aware_retriever, create_retrieval_chain, ConversationalRetrievalChain
from langchain.document_loaders import TextLoader
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.document_loaders.directory import DirectoryLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_experimental.chat_models import Llama2Chat
from langchain_community.chat_models.llamacpp import ChatLlamaCpp

# LangSmith tracing configuration.
# The API key is read from the "lang_api_key" environment variable. Assigning
# None into os.environ raises TypeError, so tracing is only switched on when a
# key is actually available; without one the app still starts, just untraced.
lang_api_key = os.getenv("lang_api_key")

os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
os.environ["LANGCHAIN_PROJECT"] = "Lithuanian_Law_RAG_QA"
if lang_api_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_API_KEY"] = lang_api_key





def create_retriever_from_chroma(vectorstore_path="./docs/chroma/", search_type='mmr', k=7,
                                 chunk_size=300, chunk_overlap=30, lambda_mult=0.7):
    """Return a retriever over a Chroma vector store, building the store if needed.

    If ``vectorstore_path`` exists and is non-empty, the persisted store is
    opened. Otherwise every ``*.txt`` file in ``./data/`` is loaded, split into
    chunks, embedded and persisted to ``vectorstore_path``.

    Args:
        vectorstore_path: Directory holding (or receiving) the Chroma store.
        search_type: Search type forwarded to ``as_retriever``
            (the app uses ``"mmr"`` or ``"similarity"``).
        k: Number of documents the retriever should return.
        chunk_size: Chunk size for the splitter; only used when the store is
            created, not when an existing store is opened.
        chunk_overlap: Chunk overlap for the splitter; only used when the
            store is created.
        lambda_mult: Forwarded to the retriever's search kwargs when
            ``search_type`` is ``"mmr"``; not sent for other search types.

    Returns:
        The retriever produced by ``vectorstore.as_retriever``.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-large-en-v1.5",
        model_kwargs={
            'device': 'cpu',
            # This used to be the string 'False', which is truthy, so remote
            # code was effectively enabled all along. Keep that behaviour but
            # state it honestly.
            # NOTE(review): this executes code shipped with the model repo —
            # confirm the model really requires it and that this is acceptable.
            "trust_remote_code": True,
        },
        encode_kwargs={'normalize_embeddings': True},
    )

    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        # Re-use the persisted store.
        vectorstore = Chroma(persist_directory=vectorstore_path, embedding_function=embeddings)
    else:
        st.write("Vector store doesnt exist and will be created now")
        loader = DirectoryLoader('./data/', glob="./*.txt", loader_cls=TextLoader)
        docs = loader.load()

        # Separators are regular expressions, tried from coarsest to finest.
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n \n\n", "\n\n\n", "\n\n", r"In \[[0-9]+\]", r"\n+", r"\s+"],
            is_separator_regex=True,
        )
        split_docs = text_splitter.split_documents(docs)

        vectorstore = Chroma.from_documents(
            documents=split_docs, embedding=embeddings, persist_directory=vectorstore_path
        )

    search_kwargs = {"k": k}
    if search_type == "mmr":
        # Previously lambda_mult was accepted but silently dropped.
        search_kwargs["lambda_mult"] = lambda_mult

    return vectorstore.as_retriever(search_type=search_type, search_kwargs=search_kwargs)






def main():
    """Render the Streamlit page and answer the question typed by the user.

    Sets up the page, seeds ``st.session_state["messages"]`` with a greeting,
    lets the user pick the search type and the number of documents to
    retrieve, builds the retriever, and passes any submitted question to
    ``handle_userinput``.
    """
    st.set_page_config(page_title="Chat with multiple Lithuanian Law Documents: ",
                       page_icon=":books:")

    st.header("Chat with multiple Lithuanian Law Documents:" ":books:")

    st.markdown("###### Hi, I am Birute (Powered by gemma-2-2b-it-Q8 model), chat assistant, based on republic of Lithuania law documents. You can choose below information retrieval type and how many documents you want to be retrieved.")
    st.markdown("Available Documents: LR_Civil_Code_2022, LR_Constitution_2022, LR_Criminal_Code_2018, LR_Criminal_Procedure_code_2022,LR_Labour_code_2010. P.S it's a shame that there are no newest documents translations into English... ")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hi, I'm a chatbot who is  based on respublic of Lithuania law documents. How can I help you?"}
        ]

    # Label fixed: the first option is "mmr", not "similarity".
    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr) , Similarity search (similarity). Default value (similarity)]",
        options=["mmr", "similarity"],
        index=1,
    )

    # Label fixed to match the actual default below (4); it used to claim 5.
    k = st.select_slider(
        "Select amount of documents to be retrieved. Default value (4): ",
        options=list(range(2, 16)),
        value=4,
    )
    retriever = create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type=search_type, k=k, chunk_size=350, chunk_overlap=30)

    # Streamlit re-executes this function on every widget interaction. The
    # chain does not depend on the widget values, so build it once per session
    # instead of reloading the model on each rerun.
    if "rag_chain" not in st.session_state:
        st.session_state["rag_chain"] = create_conversational_rag_chain(retriever)
    rag_chain = st.session_state["rag_chain"]

    if user_question := st.text_input("Ask a question about your documents:"):
        handle_userinput(user_question, retriever, rag_chain)
        
    
    
    
    
 
    
    

    


def handle_userinput(user_question,retriever,rag_chain):
    """Answer one user question and display the exchange.

    Records the question in ``st.session_state.messages``, retrieves documents
    for it, lists them in the sidebar, invokes ``rag_chain`` with the document
    text as context, and records and displays the response.

    Args:
        user_question: The question text entered by the user.
        retriever: Retriever used to fetch documents for the question.
        rag_chain: Chain invoked with ``{"context", "question"}``.
    """
    st.session_state.messages.append({"role": "user", "content": user_question})
    st.chat_message("user").write(user_question)

    # `invoke` replaces the deprecated `get_relevant_documents`.
    docs = retriever.invoke(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        with st.spinner("Processing"):
            for doc in docs:
                st.write(f"Document: {doc}")

    # Join the chunks into plain text. Passing the list itself would put its
    # repr (brackets, quotes, escaped newlines) into the prompt.
    context = "\n\n".join(doc.page_content for doc in docs)

    response = rag_chain.invoke({"context": context, "question": user_question})
    st.session_state.messages.append({"role": "assistant", "content": response})
    st.chat_message("assistant").write(response)
    

                
            



def create_conversational_rag_chain(retriever):
    """Build the prompt -> LLM -> string-output chain used to answer questions.

    Args:
        retriever: Currently unused; kept so existing callers keep working.
            Retrieval is done by the caller, which passes the text in as
            ``context``.

    Returns:
        A chain that expects ``{"context": ..., "question": ...}`` and returns
        the model's answer as a string.
    """
    # Streams generated tokens to stdout while the model is running.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    # NOTE(review): n_ctx=25000 may exceed the context window this model was
    # trained with, and repeat_penalty=1.7 is unusually high — confirm both.
    llm = ChatLlamaCpp(
        model_path="gemma-2-2b-it-Q8_0.gguf",
        seed=41,
        n_gpu_layers=0,
        temperature=0.0,
        n_ctx=25000,
        n_batch=2000,
        max_tokens=250,
        repeat_penalty=1.7,
        last_n_tokens_size=250,
        callback_manager=callback_manager,
        verbose=False,
    )

    # Prompt text fixed: removed a stray unbalanced double quote and corrected
    # the spelling of "concise"; this text is sent verbatim to the model.
    template = """Answer the question, based only on the following context:
    {context}. Be concise. Avoid naming. Contextualize your answer.
    Question: {question}
    """

    prompt = ChatPromptTemplate.from_template(template)

    return prompt | llm | StrOutputParser()


 

if __name__ == "__main__":
    main()