He Bo commited on
Commit
e0bec4f
1 Parent(s): ffa407c
Files changed (2) hide show
  1. app.py +217 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import gradio as gr
4
+ from datetime import datetime
5
+
6
+ invoke_url = "https://lvc8v7x2ak.execute-api.us-west-2.amazonaws.com/prod"
7
+ api = invoke_url + '/langchain_processor_qa?query='
8
+
9
+ # chinese_index = "smart_search_qa_test_0614_wuyue_2"
10
+ chinese_index = "smart_search_qa_demo_0618_cn_3"
11
+ english_index = "smart_search_qa_demo_0618_en_2"
12
+
13
+ chinese_prompt = """基于以下已知信息,简洁和专业的来回答用户的问题,并告知是依据哪些信息来进行回答的。
14
+ 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
15
+
16
+ 问题: {question}
17
+ =========
18
+ {context}
19
+ =========
20
+ 答案:"""
21
+
22
+ english_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
23
+ {context}
24
+
25
+ Question: {question}
26
+ Answer:"""
27
+
28
+
29
+ zh_prompt_template = """
30
+ 如下三个反括号中是aws的产品文档片段
31
+ ```
32
+ {text}
33
+ ```
34
+ 请基于这些文档片段自动生成尽可能多的问题以及对应答案, 尽可能详细全面, 并且遵循如下规则:
35
+ 1. "aws"需要一直被包含在Question中
36
+ 2. 答案部分的内容必须为上述aws的产品文档片段的内容摘要
37
+ 3. 问题部分需要以"Question:"开始
38
+ 4. 答案部分需要以"Answer:"开始
39
+ """
40
+
41
+ en_prompt_template = """
42
+ Here is one page of aws's product document
43
+ ```
44
+ {text}
45
+ ```
46
+ Please automatically generate FAQs based on these document fragments, with answers that should not exceed 50 words as much as possible, and follow the following rules:
47
+ 1. 'aws' needs to be included in the question
48
+ 2. The content of the answer section must be a summary of the content of the above document fragments
49
+
50
+ The Question and Answer are:
51
+ """
52
+
53
+
54
+ def get_answer(question,session_id,language,llm_instance,prompt,index,top_k):
55
+
56
+ if len(question) > 0:
57
+ url = api + question
58
+ else:
59
+ url = api + "hello"
60
+
61
+ # task='chat'
62
+ # if question.find('电商')>=0 or question.find('开店')>=0 or question.find('亚马逊')>=0:
63
+ # task = 'qa'
64
+ # url += ('&task='+task)
65
+
66
+
67
+ if language == "english":
68
+ url += '&language=english'
69
+ url += ('&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2')
70
+ url += ('&llm_embedding_name=pytorch-inference-vicuna-v1-1-b')
71
+ elif language == "chinese":
72
+ url += '&language=chinese'
73
+ url += ('&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1')
74
+ if llm_instance == 'p3-8x':
75
+ url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-p3-8x')
76
+ elif llm_instance == 'g4dn-8x':
77
+ url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-8x')
78
+
79
+ if len(session_id) > 0:
80
+ url += ('&session_id='+session_id)
81
+
82
+
83
+ if len(prompt) > 0:
84
+ url += ('&prompt='+prompt)
85
+
86
+
87
+ if len(index) > 0:
88
+ url += ('&index='+index)
89
+ else:
90
+ if language == "chinese" and len(chinese_index) >0:
91
+ url += ('&index='+chinese_index)
92
+ elif language == "english" and len(english_index) >0:
93
+ url += ('&index='+english_index)
94
+
95
+ if int(top_k) > 0:
96
+ url += ('&top_k='+str(top_k))
97
+
98
+ print("url:",url)
99
+
100
+ now1 = datetime.now()#begin time
101
+ response = requests.get(url)
102
+ now2 = datetime.now()#endtime
103
+ request_time = now2-now1
104
+ print("request takes time:",request_time)
105
+
106
+ result = response.text
107
+
108
+ result = json.loads(result)
109
+ print('result:',result)
110
+
111
+ answer = result['suggestion_answer']
112
+ source_list = []
113
+ if 'source_list' in result.keys():
114
+ source_list = result['source_list']
115
+
116
+ print("answer:",answer)
117
+
118
+ source_str = ""
119
+ for i in range(len(source_list)):
120
+ item = source_list[i]
121
+ print('item:',item)
122
+ _id = "num:" + str(item['id'])
123
+ source = "source:" + item['source']
124
+ score = "score:" + str(item['score'])
125
+ sentence = "sentence" + item['sentence']
126
+ paragraph = "paragraph:" + item['paragraph']
127
+ source_str += (_id + " " + source + " " + score + '\n')
128
+ source_str += sentence + '\n'
129
+ source_str += paragraph + '\n\n'
130
+
131
+ confidence = ""
132
+ if 'confidence' in result.keys():
133
+ confidence = result['confidence']
134
+
135
+ return answer,confidence,source_str,url,request_time
136
+
137
+
138
+ def get_summarize(texts,language,llm_instance,prompt):
139
+
140
+ url = api + texts
141
+ url += '&task=summarize'
142
+
143
+ if language == "english":
144
+ url += '&language=english'
145
+ url += ('&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2')
146
+ url += ('&llm_embedding_name=pytorch-inference-vicuna-v1-1-b')
147
+ # url += ('&prompt='+en_prompt_template)
148
+
149
+ elif language == "chinese":
150
+ url += '&language=chinese'
151
+ url += ('&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1')
152
+ # url += ('&prompt='+zh_prompt_template)
153
+ if llm_instance == '2x':
154
+ url += ('&llm_embedding_name=pytorch-inference-chatglm-v1')
155
+ elif llm_instance == '8x':
156
+ url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-8x')
157
+
158
+ if len(prompt) > 0:
159
+ url += ('&prompt='+prompt)
160
+
161
+ print('url:',url)
162
+ response = requests.get(url)
163
+ result = response.text
164
+ result = json.loads(result)
165
+ print('result1:',result)
166
+
167
+ answer = result['summarize']
168
+
169
+ if language == 'english' and answer.find('The Question and Answer are:') > 0:
170
+ answer=answer.split('The Question and Answer are:')[-1].strip()
171
+
172
+ return answer
173
+
174
+ demo = gr.Blocks(title="亚马逊云科技智能问答解决方案指南")
175
+ with demo:
176
+ gr.Markdown(
177
+ "# <center>亚马逊云科技智能问答解决方案指南"
178
+ )
179
+
180
+ with gr.Tabs():
181
+ with gr.TabItem("Question Answering"):
182
+
183
+ with gr.Row():
184
+ with gr.Column():
185
+ query_textbox = gr.Textbox(label="Query")
186
+ session_id_textbox = gr.Textbox(label="Session ID")
187
+ qa_button = gr.Button("Summit")
188
+
189
+ qa_language_radio = gr.Radio(["chinese", "english"],value="chinese",label="Language")
190
+ qa_llm_radio = gr.Radio(["p3-8x", "g4dn-8x"],value="p3-8x",label="Chinese llm instance")
191
+ qa_prompt_textbox = gr.Textbox(label="Prompt( must include {context} and {question} )",lines=2)
192
+ qa_index_textbox = gr.Textbox(label="Index")
193
+ qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM",value=1, minimum=1, maximum=4, step=1)
194
+
195
+ #language_radio.change(fn=change_prompt, inputs=language_radio, outputs=prompt_textbox)
196
+
197
+ with gr.Column():
198
+ qa_output = [gr.outputs.Textbox(label="Answer"), gr.outputs.Textbox(label="Confidence"), gr.outputs.Textbox(label="Source"), gr.outputs.Textbox(label="Url"), gr.outputs.Textbox(label="Request time")]
199
+
200
+
201
+ with gr.TabItem("Summarize"):
202
+ with gr.Row():
203
+ with gr.Column():
204
+ text_input = gr.Textbox(label="Input texts",lines=4)
205
+ summarize_button = gr.Button("Summit")
206
+ sm_language_radio = gr.Radio(["chinese", "english"],value="chinese",label="Language")
207
+ sm_llm_radio = gr.Radio(["2x", "8x"],value="2x",label="Chinese llm instance")
208
+ sm_prompt_textbox = gr.Textbox(label="Prompt",lines=4, placeholder=zh_prompt_template)
209
+ with gr.Column():
210
+ text_output = gr.Textbox()
211
+
212
+
213
+ qa_button.click(get_answer, inputs=[query_textbox,session_id_textbox,qa_language_radio,qa_llm_radio,qa_prompt_textbox,qa_index_textbox,qa_top_k_slider], outputs=qa_output)
214
+ summarize_button.click(get_summarize, inputs=[text_input,sm_language_radio,sm_llm_radio,sm_prompt_textbox], outputs=text_output)
215
+
216
+ demo.launch()
217
+ # smart_qa.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ requests