He Bo committed e0bec4f (parent: ffa407c): "updata"

Files changed:
- app.py +217 -0
- requirements.txt +1 -0
app.py
ADDED
@@ -0,0 +1,217 @@
import requests
import json
import gradio as gr
from datetime import datetime
from urllib.parse import quote

invoke_url = "https://lvc8v7x2ak.execute-api.us-west-2.amazonaws.com/prod"
api = invoke_url + '/langchain_processor_qa?query='

# chinese_index = "smart_search_qa_test_0614_wuyue_2"
chinese_index = "smart_search_qa_demo_0618_cn_3"
english_index = "smart_search_qa_demo_0618_en_2"

# QA prompt for the Chinese model: answer concisely and professionally from the
# given context, say which information the answer is based on, and reply that
# the question cannot be answered rather than fabricating anything.
chinese_prompt = """基于以下已知信息,简洁和专业的来回答用户的问题,并告知是依据哪些信息来进行回答的。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。

问题: {question}
=========
{context}
=========
答案:"""

english_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}

Question: {question}
Answer:"""

# FAQ-generation prompt for the Chinese model: the text between the backticks
# is an aws product-document fragment; generate as many Question:/Answer: pairs
# as possible, with "aws" kept in every question and every answer summarizing
# the fragment.
zh_prompt_template = """
如下三个反括号中是aws的产品文档片段
```
{text}
```
请基于这些文档片段自动生成尽可能多的问题以及对应答案, 尽可能详细全面, 并且遵循如下规则:
1. "aws"需要一直被包含在Question中
2. 答案部分的内容必须为上述aws的产品文档片段的内容摘要
3. 问题部分需要以"Question:"开始
4. 答案部分需要以"Answer:"开始
"""

en_prompt_template = """
Here is one page of aws's product document
```
{text}
```
Please automatically generate FAQs based on these document fragments, with answers that should not exceed 50 words where possible, and follow these rules:
1. 'aws' needs to be included in the question
2. The content of the answer section must be a summary of the content of the above document fragments

The Question and Answer are:
"""


def get_answer(question, session_id, language, llm_instance, prompt, index, top_k):
    # URL-encode the free-text question so characters such as '&', '=' and
    # spaces cannot corrupt the query string.
    if len(question) > 0:
        url = api + quote(question)
    else:
        url = api + "hello"

    # task = 'chat'
    # if question.find('电商') >= 0 or question.find('开店') >= 0 or question.find('亚马逊') >= 0:
    #     task = 'qa'
    # url += ('&task=' + task)

    if language == "english":
        url += '&language=english'
        url += ('&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2')
        url += ('&llm_embedding_name=pytorch-inference-vicuna-v1-1-b')
    elif language == "chinese":
        url += '&language=chinese'
        url += ('&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1')
        if llm_instance == 'p3-8x':
            url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-p3-8x')
        elif llm_instance == 'g4dn-8x':
            url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-8x')

    if len(session_id) > 0:
        url += ('&session_id=' + session_id)

    if len(prompt) > 0:
        url += ('&prompt=' + quote(prompt))

    if len(index) > 0:
        url += ('&index=' + index)
    else:
        if language == "chinese" and len(chinese_index) > 0:
            url += ('&index=' + chinese_index)
        elif language == "english" and len(english_index) > 0:
            url += ('&index=' + english_index)

    if int(top_k) > 0:
        url += ('&top_k=' + str(top_k))

    print("url:", url)

    now1 = datetime.now()  # begin time
    response = requests.get(url)
    now2 = datetime.now()  # end time
    request_time = now2 - now1
    print("request takes time:", request_time)

    result = json.loads(response.text)
    print('result:', result)

    answer = result['suggestion_answer']
    source_list = []
    if 'source_list' in result.keys():
        source_list = result['source_list']

    print("answer:", answer)

    # Render the retrieved passages as one text block for the Source output.
    source_str = ""
    for i in range(len(source_list)):
        item = source_list[i]
        print('item:', item)
        _id = "num:" + str(item['id'])
        source = "source:" + item['source']
        score = "score:" + str(item['score'])
        sentence = "sentence:" + item['sentence']
        paragraph = "paragraph:" + item['paragraph']
        source_str += (_id + " " + source + " " + score + '\n')
        source_str += sentence + '\n'
        source_str += paragraph + '\n\n'

    confidence = ""
    if 'confidence' in result.keys():
        confidence = result['confidence']

    return answer, confidence, source_str, url, request_time


def get_summarize(texts, language, llm_instance, prompt):
    url = api + quote(texts)
    url += '&task=summarize'

    if language == "english":
        url += '&language=english'
        url += ('&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2')
        url += ('&llm_embedding_name=pytorch-inference-vicuna-v1-1-b')
        # url += ('&prompt=' + en_prompt_template)
    elif language == "chinese":
        url += '&language=chinese'
        url += ('&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1')
        # url += ('&prompt=' + zh_prompt_template)
        if llm_instance == '2x':
            url += ('&llm_embedding_name=pytorch-inference-chatglm-v1')
        elif llm_instance == '8x':
            url += ('&llm_embedding_name=pytorch-inference-chatglm-v1-8x')

    if len(prompt) > 0:
        url += ('&prompt=' + quote(prompt))

    print('url:', url)
    response = requests.get(url)
    result = json.loads(response.text)
    print('result1:', result)

    answer = result['summarize']

    # The English model tends to echo the prompt; keep only the generated part.
    if language == 'english' and answer.find('The Question and Answer are:') > 0:
        answer = answer.split('The Question and Answer are:')[-1].strip()

    return answer


# Page title: "亚马逊云科技智能问答解决方案指南" = "Amazon Web Services
# Intelligent Q&A Solution Guide"
demo = gr.Blocks(title="亚马逊云科技智能问答解决方案指南")
with demo:
    gr.Markdown(
        "# <center>亚马逊云科技智能问答解决方案指南"
    )

    with gr.Tabs():
        with gr.TabItem("Question Answering"):
            with gr.Row():
                with gr.Column():
                    query_textbox = gr.Textbox(label="Query")
                    session_id_textbox = gr.Textbox(label="Session ID")
                    qa_button = gr.Button("Submit")

                    qa_language_radio = gr.Radio(["chinese", "english"], value="chinese", label="Language")
                    qa_llm_radio = gr.Radio(["p3-8x", "g4dn-8x"], value="p3-8x", label="Chinese LLM instance")
                    qa_prompt_textbox = gr.Textbox(label="Prompt (must include {context} and {question})", lines=2)
                    qa_index_textbox = gr.Textbox(label="Index")
                    qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM", value=1, minimum=1, maximum=4, step=1)

                    # language_radio.change(fn=change_prompt, inputs=language_radio, outputs=prompt_textbox)

                with gr.Column():
                    qa_output = [
                        gr.Textbox(label="Answer"),
                        gr.Textbox(label="Confidence"),
                        gr.Textbox(label="Source"),
                        gr.Textbox(label="Url"),
                        gr.Textbox(label="Request time"),
                    ]

        with gr.TabItem("Summarize"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Input texts", lines=4)
                    summarize_button = gr.Button("Submit")
                    sm_language_radio = gr.Radio(["chinese", "english"], value="chinese", label="Language")
                    sm_llm_radio = gr.Radio(["2x", "8x"], value="2x", label="Chinese LLM instance")
                    sm_prompt_textbox = gr.Textbox(label="Prompt", lines=4, placeholder=zh_prompt_template)
                with gr.Column():
                    text_output = gr.Textbox()

    qa_button.click(get_answer, inputs=[query_textbox, session_id_textbox, qa_language_radio, qa_llm_radio, qa_prompt_textbox, qa_index_textbox, qa_top_k_slider], outputs=qa_output)
    summarize_button.click(get_summarize, inputs=[text_input, sm_language_radio, sm_llm_radio, sm_prompt_textbox], outputs=text_output)

demo.launch()
# smart_qa.launch(share=True)
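For reference, get_answer ultimately issues a plain GET against the API Gateway stage above. Below is a minimal standalone sketch of an equivalent call; the example question is hypothetical, and the params dict (rather than manual string concatenation) is an illustrative choice, while the endpoint, parameter names, and values all come from app.py.

import requests

# Mirrors the URL get_answer builds for a Chinese query against the default
# index; requests encodes the params dict into the query string.
resp = requests.get(
    "https://lvc8v7x2ak.execute-api.us-west-2.amazonaws.com/prod/langchain_processor_qa",
    params={
        "query": "什么是Amazon S3?",  # hypothetical example question
        "language": "chinese",
        "embedding_endpoint_name": "huggingface-inference-text2vec-base-chinese-v1",
        "llm_embedding_name": "pytorch-inference-chatglm-v1-p3-8x",
        "index": "smart_search_qa_demo_0618_cn_3",
        "top_k": 1,
    },
)
result = resp.json()
print(result["suggestion_answer"])  # may also carry 'confidence' and 'source_list'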
requirements.txt
ADDED
@@ -0,0 +1 @@
requests
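Note that app.py also imports gradio, which is not listed here; on a Hugging Face Space the Gradio runtime is preinstalled by the SDK, which is presumably why only requests is pinned. Running the app outside Spaces would need a fuller file, sketched below (the unpinned gradio line is an assumption):

requests
gradio  # preinstalled on Spaces; required for local runs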