He Bo commited on
Commit
8aa0a12
1 Parent(s): 6ebdb42
Files changed (1) hide show
  1. app.py +36 -22
app.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import gradio as gr
4
  from datetime import datetime
5
 
6
- invoke_url = "https://es9ke27g82.execute-api.us-west-2.amazonaws.com/prod"
7
  api = invoke_url + '/langchain_processor_qa?query='
8
 
9
  # chinese_index = "smart_search_qa_test_0614_wuyue_2"
@@ -63,7 +63,7 @@ Please automatically generate as many questions as possible based on this manual
63
  """
64
 
65
 
66
- def get_answer(question,session_id,language,prompt,index,top_k,temperature):
67
 
68
  if len(question) > 0:
69
  url = api + question
@@ -104,19 +104,28 @@ def get_answer(question,session_id,language,prompt,index,top_k,temperature):
104
  if len(prompt) > 0:
105
  url += ('&prompt='+prompt)
106
 
107
-
108
- if len(index) > 0:
109
- url += ('&index='+index)
110
- else:
111
- if language.find("chinese") >= 0 and len(chinese_index) >0:
112
- url += ('&index='+chinese_index)
113
- elif language == "english" and len(english_index) >0:
114
- url += ('&index='+english_index)
115
-
 
 
 
 
 
116
  if int(top_k) > 0:
117
  url += ('&top_k='+str(top_k))
118
 
119
  url += ('&temperature='+str(temperature))
 
 
 
 
120
  print("url:",url)
121
 
122
  now1 = datetime.now()#begin time
@@ -144,31 +153,35 @@ def get_answer(question,session_id,language,prompt,index,top_k,temperature):
144
  _id = "num:" + str(item['id'])
145
  source = "source:" + item['source']
146
  score = "score:" + str(item['score'])
147
- sentence = "sentence" + item['sentence']
148
  paragraph = "paragraph:" + item['paragraph']
149
  source_str += (_id + " " + source + " " + score + '\n')
150
- source_str += sentence + '\n'
151
  source_str += paragraph + '\n\n'
152
 
153
  confidence = ""
154
- query_doc_scores = ''
155
  if 'query_doc_scores' in result.keys():
156
- query_doc_scores = result['query_doc_scores']
157
- confidence += ("query_doc_scores:" + str(query_doc_scores) + '\n')
 
158
 
159
- qa_relate_score = ''
160
  if 'qa_relate_score' in result.keys():
161
  qa_relate_score = result['qa_relate_score']
 
162
  confidence += ("qa_relate_score:" + str(qa_relate_score) + '\n')
163
 
164
- answer_relate_scores = ''
165
  if 'answer_relate_scores' in result.keys():
166
- answer_relate_scores = result['answer_relate_scores']
167
- confidence += ("answer_relate_scores:" + str(answer_relate_scores) + '\n')
 
168
 
169
- list_overlap_score = ''
170
  if 'list_overlap_score' in result.keys():
171
  list_overlap_score = result['list_overlap_score']
 
172
  confidence += ("list_overlap_score:" + str(list_overlap_score) + '\n')
173
 
174
 
@@ -230,6 +243,7 @@ with demo:
230
  qa_language_radio = gr.Radio(["chinese-glm","chinese-glm2", "english"],value="chinese-glm",label="Language")
231
  # qa_llm_radio = gr.Radio(["p3-8x", "g4dn-8x"],value="p3-8x",label="Chinese llm instance")
232
  qa_prompt_textbox = gr.Textbox(label="Prompt( must include {context} and {question} )",placeholder=chinese_prompt,lines=2)
 
233
  qa_index_textbox = gr.Textbox(label="Index")
234
  qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM",value=1, minimum=1, maximum=4, step=1)
235
 
@@ -253,7 +267,7 @@ with demo:
253
  text_output = gr.Textbox()
254
 
255
 
256
- qa_button.click(get_answer, inputs=[query_textbox,session_id_textbox,qa_language_radio,qa_prompt_textbox,qa_index_textbox,qa_top_k_slider,temperature_slider], outputs=qa_output)
257
  summarize_button.click(get_summarize, inputs=[text_input,sm_language_radio,sm_prompt_textbox], outputs=text_output)
258
 
259
  demo.launch()
 
3
  import gradio as gr
4
  from datetime import datetime
5
 
6
+ invoke_url = "https://02u4taf9pf.execute-api.us-west-2.amazonaws.com/prod"
7
  api = invoke_url + '/langchain_processor_qa?query='
8
 
9
  # chinese_index = "smart_search_qa_test_0614_wuyue_2"
 
63
  """
64
 
65
 
66
+ def get_answer(question,session_id,language,prompt,search_engine,index,top_k,temperature):
67
 
68
  if len(question) > 0:
69
  url = api + question
 
104
  if len(prompt) > 0:
105
  url += ('&prompt='+prompt)
106
 
107
+ if search_engine == "OpenSearch":
108
+ url += ('&search_engine=opensearch')
109
+ if len(index) > 0:
110
+ url += ('&index='+index)
111
+ else:
112
+ if language.find("chinese") >= 0 and len(chinese_index) >0:
113
+ url += ('&index='+chinese_index)
114
+ elif language == "english" and len(english_index) >0:
115
+ url += ('&index='+english_index)
116
+ elif search_engine == "Kendra":
117
+ url += ('&search_engine=kendra')
118
+ if len(index) > 0:
119
+ url += ('&kendra_index_id='+index)
120
+
121
  if int(top_k) > 0:
122
  url += ('&top_k='+str(top_k))
123
 
124
  url += ('&temperature='+str(temperature))
125
+ url += ('&cal_qa_relate_score=true')
126
+ url += ('&cal_answer_relate_scores=true')
127
+ url += ('&cal_list_overlap_score=true')
128
+
129
  print("url:",url)
130
 
131
  now1 = datetime.now()#begin time
 
153
  _id = "num:" + str(item['id'])
154
  source = "source:" + item['source']
155
  score = "score:" + str(item['score'])
156
+ sentence = "sentence:" + item['sentence']
157
  paragraph = "paragraph:" + item['paragraph']
158
  source_str += (_id + " " + source + " " + score + '\n')
159
+ # source_str += sentence + '\n'
160
  source_str += paragraph + '\n\n'
161
 
162
  confidence = ""
163
+ query_doc_scores = []
164
  if 'query_doc_scores' in result.keys():
165
+ query_doc_scores = list(result['query_doc_scores'])
166
+ if len(query_doc_scores) > 0:
167
+ confidence += ("query_doc_scores:" + str(float(query_doc_scores[0])) + '\n')
168
 
169
+ qa_relate_score = 0
170
  if 'qa_relate_score' in result.keys():
171
  qa_relate_score = result['qa_relate_score']
172
+ if float(qa_relate_score) > 0:
173
  confidence += ("qa_relate_score:" + str(qa_relate_score) + '\n')
174
 
175
+ answer_relate_scores = []
176
  if 'answer_relate_scores' in result.keys():
177
+ answer_relate_scores = list(result['answer_relate_scores'])
178
+ if len(answer_relate_scores) > 0:
179
+ confidence += ("answer_relate_scores:" + str(float(answer_relate_scores[0])) + '\n')
180
 
181
+ list_overlap_score = 0
182
  if 'list_overlap_score' in result.keys():
183
  list_overlap_score = result['list_overlap_score']
184
+ if float(list_overlap_score) > 0:
185
  confidence += ("list_overlap_score:" + str(list_overlap_score) + '\n')
186
 
187
 
 
243
  qa_language_radio = gr.Radio(["chinese-glm","chinese-glm2", "english"],value="chinese-glm",label="Language")
244
  # qa_llm_radio = gr.Radio(["p3-8x", "g4dn-8x"],value="p3-8x",label="Chinese llm instance")
245
  qa_prompt_textbox = gr.Textbox(label="Prompt( must include {context} and {question} )",placeholder=chinese_prompt,lines=2)
246
+ qa_search_engine_radio = gr.Radio(["OpenSearch","Kendra"],value="OpenSearch",label="Search engine")
247
  qa_index_textbox = gr.Textbox(label="Index")
248
  qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM",value=1, minimum=1, maximum=4, step=1)
249
 
 
267
  text_output = gr.Textbox()
268
 
269
 
270
+ qa_button.click(get_answer, inputs=[query_textbox,session_id_textbox,qa_language_radio,qa_prompt_textbox,qa_search_engine_radio,qa_index_textbox,qa_top_k_slider,temperature_slider], outputs=qa_output)
271
  summarize_button.click(get_summarize, inputs=[text_input,sm_language_radio,sm_prompt_textbox], outputs=text_output)
272
 
273
  demo.launch()