sivan22 committed on
Commit
13791ef
•
1 Parent(s): 6fb1bed
Files changed (5)
  1. App.py +100 -0
  2. __init__.py +13 -0
  3. requirements.txt +9 -0
  4. run.bat +2 -0
  5. utils.py +28 -0
App.py ADDED
@@ -0,0 +1,100 @@
+ import streamlit as st
+ from streamlit.logger import get_logger
+ import datasets
+ import pandas as pd
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from sentence_transformers import util
+
+
+ LOGGER = get_logger(__name__)
+
+
+ @st.cache_data
+ def get_df() -> pd.DataFrame:
+     # Load the pre-computed embeddings dataset and expose it as a DataFrame.
+     ds = datasets.load_dataset('sivan22/yalkut-yosef-embeddings')
+     df = pd.DataFrame.from_dict(ds['train'])
+     return df
+
+ @st.cache_resource
+ def get_model() -> HuggingFaceEmbeddings:
+     model_name = "intfloat/multilingual-e5-large"
+     model_kwargs = {'device': 'cpu'}  # 'cpu' or 'cuda'
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings_model = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs,
+     )
+     return embeddings_model
+
+ @st.cache_resource
+ def get_chat_api(api_key: str):
+     chat = ChatOpenAI(model="gpt-3.5-turbo-16k", api_key=api_key)
+     return chat
+
+
+ def get_results(embeddings_model, query, df, num_of_results) -> pd.DataFrame:
+     # E5 models expect the "query: " prefix on search queries.
+     embeddings = embeddings_model.embed_query('query: ' + query)
+     # Dot-product similarity against the pre-computed passage embeddings.
+     df['similarity'] = df['embeddings'].apply(lambda x: util.dot_score(x, embeddings))
+     results = df.sort_values(by='similarity', ascending=False)
+     return results.head(num_of_results)
+
+ def get_llm_results(query, chat, results):
+
+     prompt_template = PromptTemplate.from_template(
+         """
+         the question is: {query}
+         the possible answers are:
+         {answers}
+
+         """)
+
+     messages = [
+         SystemMessage(content="You're a helpful assistant. Given a question, filter and sort the possible answers to the given question by relevancy, drop the irrelevant answers and return the results in the following JSON format (scores are between 0 and 1): {\"answer\": \"score\", \"answer\": \"score\"}."),
+         HumanMessage(content=prompt_template.format(query=query, answers='\n'.join(results['text'].head(10).tolist()))),
+     ]
+
+     response = chat.invoke(messages)
+     llm_results_df = pd.read_json(response.content, orient='index')
+     return llm_results_df
+
+
+ def run():
+
+     st.set_page_config(
+         page_title="חיפוש סמנטי בילקוט יוסף",  # "Semantic search in Yalkut Yosef"
+         page_icon="📚",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+
+     st.write("חיפוש חכם בספר ילקוט יוסף קיצור שולחן ערוך")  # "Smart search in the book Yalkut Yosef Kitzur Shulchan Aruch"
+
+     embeddings_model = get_model()
+     df = get_df()
+
+     user_input = st.text_input('כתוב כאן את שאלתך', placeholder='כמה נרות מדליקים בכל לילה מלילות החנוכה')  # "Write your question here" / "How many candles are lit on each night of Hanukkah"
+     num_of_results = st.sidebar.slider('מספר התוצאות שברצונך להציג:', 1, 25, 5)  # "Number of results to display:"
+     use_llm = st.sidebar.checkbox("השתמש במודל שפה כדי לשפר תוצאות", False)  # "Use a language model to improve results"
+     openAikey = st.sidebar.text_input("OpenAI API key", type="password")
+
+     if (st.button('חפש') or user_input) and user_input != "":  # "Search"
+
+         results = get_results(embeddings_model, user_input, df, num_of_results)
+
+         if use_llm:
+             chat = get_chat_api(openAikey)
+             llm_results = get_llm_results(user_input, chat, results)
+             st.write(llm_results)
+
+         else:
+             st.write(results[['siman', 'sek', 'text']].head(10))
+
+
+ if __name__ == "__main__":
+     run()
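
For reference, a minimal sketch of the ranking step that get_results() performs, run outside Streamlit on toy data. The two-dimensional vectors and sample texts below are illustrative stand-ins only; the real app embeds queries with intfloat/multilingual-e5-large and scores them against the sivan22/yalkut-yosef-embeddings dataset.

import pandas as pd
from sentence_transformers import util

# Toy stand-ins for the dataset's 'text' and 'embeddings' columns.
df = pd.DataFrame({
    'text': ['Hanukkah candles', 'morning prayer', 'laws of Shabbat'],
    'embeddings': [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]],
})

# Toy stand-in for embeddings_model.embed_query('query: ' + user_input).
query_embedding = [1.0, 0.0]

# Same logic as get_results(): dot-product similarity, sorted descending.
# float() is added here only to keep the printed column readable; the app
# stores the raw tensor scores.
df['similarity'] = df['embeddings'].apply(lambda x: float(util.dot_score(x, query_embedding)))
print(df.sort_values(by='similarity', ascending=False).head(2))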
__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ pandas
+ streamlit
+ torch
+ transformers
+ datasets
+ langchain_huggingface
+ langchain_openai
+ langchain
+ sentence_transformers
run.bat ADDED
@@ -0,0 +1,2 @@
+ pip install -r requirements.txt
+ streamlit run app.py
utils.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ import textwrap
+
+ import streamlit as st
+
+
+ def show_code(demo):
+     """Showing the code of the demo."""
+     show_code = st.sidebar.checkbox("Show code", True)
+     if show_code:
+         # Showing the code of the demo.
+         st.markdown("## Code")
+         sourcelines, _ = inspect.getsourcelines(demo)
+         st.code(textwrap.dedent("".join(sourcelines[1:])))
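
A minimal usage sketch for show_code, following the Streamlit demo template this helper comes from. The demo() function below is hypothetical, purely for illustration: the helper is called with a page's main function so that, when the sidebar "Show code" box is ticked, that function's own source is rendered below the demo.

import streamlit as st
from utils import show_code

def demo():
    # Hypothetical page body.
    st.write("Hello from the demo")

demo()
show_code(demo)  # renders demo's source under a "## Code" heading when the checkbox is ticked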