Commit: disable conversational memory with zephyr
Changed file: streamlit_app.py (+12 -7)
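In summary: the global fallback becomes st.session_state['memory'] = None, and init_qa() now decides per model whether conversational memory is available. The chatgpt-3.5-turbo branch and the preceding Hugging Face model branch (whose condition sits outside the hunk) assign a ConversationBufferWindowMemory(k=4), while zephyr-7b-beta sets the memory to None, so that model runs without conversational memory. The sidebar "Reset chat memory." button is disabled when the loaded model has no memory. The remaining hunks pass explicit arguments: the uploader gets type=("pdf", "txt") and on_change=new_file, create_memory_embeddings() gets perc_overlap=0.1 and include_biblio=True, and query_document() gets context_size=context_size.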
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -54,7 +54,7 @@ if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
 if 'memory' not in st.session_state:
-    st.session_state['memory'] =
+    st.session_state['memory'] = None
 
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
@@ -117,12 +117,14 @@ def clear_memory():
 def init_qa(model, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
             embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+
         else:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
@@ -134,11 +136,13 @@ def init_qa(model, api_key=None):
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
 
     elif model == 'zephyr-7b-beta':
         chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
@@ -255,7 +259,8 @@ with st.sidebar:
         'Reset chat memory.',
         key="reset-memory-button",
         on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages."
+        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
     left_column, right_column = st.columns([1, 1])
 
@@ -267,8 +272,8 @@ with right_column:
         ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
 
     uploaded_file = st.file_uploader("Upload an article",
-
-
+                                     type=("pdf", "txt"),
+                                     on_change=new_file,
                                      disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                                               st.session_state['api_keys'],
                                      help="The full-text is extracted using Grobid. ")
@@ -335,8 +340,8 @@ if uploaded_file and not st.session_state.loaded_embeddings:
 
     st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                  chunk_size=chunk_size,
-
-
+                                                                                                 perc_overlap=0.1,
+                                                                                                 include_biblio=True)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
@@ -389,7 +394,7 @@ with right_column:
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-
+                                                                             context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
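For reference, a minimal sketch of how the window memory used above behaves, assuming LangChain's ConversationBufferWindowMemory from langchain.memory; the script below is illustrative and not part of streamlit_app.py. With k=4 only the four most recent exchanges are replayed to the model, and setting the memory to None for zephyr-7b-beta means each question is answered without any previous turns.

# Illustrative only -- not part of streamlit_app.py.
# Assumes LangChain's classic memory API (langchain.memory).
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=4)

# Record six question/answer turns, as a chat loop would after each query.
for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Only the 4 most recent exchanges survive; "question 0" and "question 1" are dropped.
print(memory.load_memory_variables({})["history"])

# Drop all stored turns, analogous to what the sidebar "Reset chat memory." button
# triggers through clear_memory().
memory.clear()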