# smart_qa / app.py
# Author: He Bo. Commit 4503d3c ("update"), 10.8 kB.
import json
import re
from datetime import datetime
from urllib.parse import quote, urlencode

import gradio as gr
import requests
# Base URL of the API Gateway stage that fronts the QA backend.
invoke_url = "https://3e2paa86c4.execute-api.us-west-2.amazonaws.com/prod"
# Request prefix for the langchain processor: the query text goes right
# after "query=", and any further parameters follow as "&name=value".
api = f"{invoke_url}/langchain_processor_qa?query="

# Default OpenSearch indexes, used when the user leaves the index box empty.
chinese_index, english_index = (
    "smart_search_qa_demo_0620_cn",
    "smart_search_qa_demo_0618_en_2",
)
# ---------------------------------------------------------------------------
# Prompt templates. The {question}, {context} and {text} placeholders are not
# filled in by this file; presumably the backend substitutes them (confirm).
# ---------------------------------------------------------------------------

# Chinese QA prompt, shown as the placeholder of the QA "Prompt" textbox.
# Translation: "Given the following extracted parts of a long document and a
# question, if you don't know the answer, just say you don't know. Do not try
# to make up an answer. Answer in Chinese." followed by Question / context /
# Answer sections.
chinese_prompt = """给定一个长文档和一个问题的以下提取部分,如果你不知道答案,就说你不知道。不要试图编造答案。用中文回答。
问题: {question}
=========
{context}
=========
答案:"""
# English counterpart of chinese_prompt; not referenced by the code shown here.
english_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Answer:"""
# Chinese FAQ-generation prompt; only referenced from commented-out code in
# get_summarize. Translation: "The text inside the three backticks below is a
# fragment of AWS product documentation. Based on it, generate as many
# questions and corresponding answers as possible, as detailed and complete as
# possible, following these rules: 1. 'aws' must always appear in the
# Question; 2. the answer must be a summary of the fragment; 3. questions
# start with 'Question:'; 4. answers start with 'Answer:'".
# NOTE(review): "反括号" likely means "反引号" (backtick).
zh_prompt_template = """
如下三个反括号中是aws的产品文档片段
```
{text}
```
请基于这些文档片段自动生成尽可能多的问题以及对应答案, 尽可能详细全面, 并且遵循如下规则:
1. "aws"需要一直被包含在Question中
2. 答案部分的内容必须为上述aws的产品文档片段的内容摘要
3. 问题部分需要以"Question:"开始
4. 答案部分需要以"Answer:"开始
"""
# English FAQ-generation prompt; only referenced from commented-out code.
# It ends with "The Question and Answer are:". For English responses,
# get_summarize keeps only the text after that marker, presumably because
# the model echoes the prompt.
en_prompt_template = """
Here is one page of aws's product document
```
{text}
```
Please automatically generate FAQs based on these document fragments, with answers that should not exceed 50 words as much as possible, and follow the following rules:
1. 'aws' needs to be included in the question
2. The content of the answer section must be a summary of the content of the above document fragments
The Question and Answer are:
"""
# Prompt shown as the placeholder of the Summarize tab's "Prompt" textbox.
EN_SUMMARIZE_PROMPT_TEMPLATE = """
Here is one page of aws's manual document
```
{text}
```
Please automatically generate as many questions as possible based on this manual document, and follow these rules:
1. "aws" should be contained in every question
2. questions start with "Question:"
3. answers begin with "Answer:"
"""
# Confidence scores the backend may return, in display order.
_CONFIDENCE_KEYS = (
    "query_docs_score",
    "query_answer_score",
    "answer_docs_score",
    "docs_list_overlap_score",
)

# (embedding endpoint, LLM endpoint) used for QA requests, per UI language.
_QA_LANGUAGE_ENDPOINTS = {
    "english": ("pytorch-inference-all-minilm-l6-v2",
                "pytorch-inference-vicuna-p3-2x"),
    "chinese": ("huggingface-inference-text2vec-base-chinese-v1",
                "pytorch-inference-chatglm2-g5-4x"),
    "chinese-tc": ("huggingface-inference-text2vec-base-chinese-v1",
                   "pytorch-inference-chatglm2-g5-4x"),
}

# Matches "aws" in any letter case, but not inside a longer ASCII word such
# as "laws" or "draws". Adjacent CJK characters or punctuation still match.
_AWS_PATTERN = re.compile(r"(?<![A-Za-z])aws(?![A-Za-z])", re.IGNORECASE)

# Upper bound on one backend call, so a stalled request cannot hang the UI
# forever. Raise it if the backend legitimately needs longer.
_QA_TIMEOUT_SECONDS = 120


def _normalize_question(question):
    """Replace standalone mentions of AWS (any case) with "亚马逊云科技"."""
    return _AWS_PATTERN.sub("亚马逊云科技", question)


def _qa_params(task_type, session_id, language, prompt, search_engine,
               index, top_k, score_type_checklist, fallback_index):
    """Build the ordered query parameters that follow ``query=`` in a QA request."""
    params = {"task": "qa" if task_type == "Knowledge base Q&A" else "chat"}
    endpoints = _QA_LANGUAGE_ENDPOINTS.get(language)
    if endpoints is not None:
        params["language"] = language
        params["embedding_endpoint_name"] = endpoints[0]
        params["llm_embedding_name"] = endpoints[1]
    if session_id:
        params["session_id"] = session_id
    if prompt:
        params["prompt"] = prompt
    if search_engine == "OpenSearch":
        params["search_engine"] = "opensearch"
        # An explicit index wins; otherwise use the per-language default.
        chosen_index = index or fallback_index
        if chosen_index:
            params["index"] = chosen_index
    elif search_engine == "Kendra":
        params["search_engine"] = "kendra"
        if index:
            params["kendra_index_id"] = index
    # The slider may deliver a float; always send a clean integer
    # ("2", not "2.0").
    top_k = int(top_k)
    if top_k > 0:
        params["top_k"] = str(top_k)
    for score_type in score_type_checklist:
        params["cal_" + score_type] = "true"
    return params


def _format_sources(source_list):
    """Render source passages as "num:/source:/score:" headers, each followed by its paragraph."""
    parts = []
    for item in source_list:
        print('item:', item)
        parts.append(
            f"num:{item['id']} source:{item['source']} score:{item['score']}\n"
            f"paragraph:{item['paragraph']}\n\n"
        )
    return "".join(parts)


def _format_confidence(result):
    """Return one "name:value" line per non-negative confidence score present in *result*."""
    lines = []
    for key in _CONFIDENCE_KEYS:
        if key in result:
            score = float(result[key])
            if score >= 0:
                lines.append(f"{key}:{score}\n")
    return "".join(lines)


def get_answer(task_type, question, session_id, language, prompt, search_engine,
               index, top_k, score_type_checklist):
    """Send a question to the QA backend and format the response for the UI.

    Args:
        task_type: "Knowledge base Q&A" selects the ``qa`` task; any other
            value selects ``chat``.
        question: User query. Standalone "aws" (any case) is rewritten to
            "亚马逊云科技"; an empty query is sent as "hello".
        session_id: Optional session identifier; omitted when empty.
        language: "english", "chinese" or "chinese-tc". Picks the embedding
            and LLM endpoints and the fallback OpenSearch index.
        prompt: Optional prompt template; omitted when empty.
        search_engine: "OpenSearch" or "Kendra".
        index: OpenSearch index name or Kendra index id. For OpenSearch an
            empty value falls back to the module's per-language default.
        top_k: Number of source passages for the LLM; sent only when positive.
        score_type_checklist: Confidence scores to request; each adds
            ``cal_<name>=true``.

    Returns:
        ``(answer, confidence, source_str, url, request_time)``, matching the
        five output textboxes. ``request_time`` is a ``datetime.timedelta``.

    Raises:
        requests.HTTPError: the backend answered with an error status.
        requests.Timeout: no response within ``_QA_TIMEOUT_SECONDS``.
        KeyError: the response lacks ``suggestion_answer``.
    """
    question = _normalize_question(question)
    print('question:', question)

    if "chinese" in language:
        fallback_index = chinese_index
    elif language == "english":
        fallback_index = english_index
    else:
        fallback_index = ""
    params = _qa_params(task_type, session_id, language, prompt, search_engine,
                        index, top_k, score_type_checklist, fallback_index)

    # Percent-encode every value so '&', '#', '+', '%' or newlines in the
    # question or prompt cannot corrupt the query string.
    url = (api + quote(question or "hello", safe="")
           + "&" + urlencode(params, quote_via=quote))
    print("url:", url)

    started = datetime.now()
    response = requests.get(url, timeout=_QA_TIMEOUT_SECONDS)
    request_time = datetime.now() - started
    print("request takes time:", request_time)
    response.raise_for_status()

    result = json.loads(response.text)
    print('result:', result)
    answer = result['suggestion_answer']
    print("answer:", answer)
    source_str = _format_sources(result.get('source_list', []))
    confidence = _format_confidence(result)
    return answer, confidence, source_str, url, request_time
# (embedding endpoint, LLM endpoint) used for summarization, per UI language.
_SUMMARIZE_LANGUAGE_ENDPOINTS = {
    "english": ("pytorch-inference-all-minilm-l6-v2",
                "pytorch-inference-vicuna-v1-1-b"),
    "chinese": ("huggingface-inference-text2vec-base-chinese-v1",
                "pytorch-inference-chatglm2-g5-2x"),
}

# Marker text (it also closes en_prompt_template). English output is trimmed
# to whatever follows its last occurrence.
_QA_ANSWER_MARKER = 'The Question and Answer are:'

# Upper bound on one backend call, so a stalled request cannot hang the UI
# forever. Raise it if the backend legitimately needs longer.
_SUMMARIZE_TIMEOUT_SECONDS = 120


def _build_summarize_url(base, texts, language, prompt):
    """Return the summarize request URL for *base*, with every value percent-encoded."""
    params = {"task": "summarize"}
    endpoints = _SUMMARIZE_LANGUAGE_ENDPOINTS.get(language)
    if endpoints is not None:
        params["language"] = language
        params["embedding_endpoint_name"] = endpoints[0]
        params["llm_embedding_name"] = endpoints[1]
    if prompt:
        params["prompt"] = prompt
    # Encoding keeps '&', '#', '+' or newlines in document text from being
    # parsed as extra query parameters or a URL fragment.
    return base + quote(texts, safe="") + "&" + urlencode(params, quote_via=quote)


def _strip_qa_preamble(answer, language):
    """For English answers, drop everything up to the last QA marker."""
    if language == "english" and _QA_ANSWER_MARKER in answer:
        return answer.split(_QA_ANSWER_MARKER)[-1].strip()
    return answer


def get_summarize(texts, language, prompt):
    """Send *texts* to the backend's ``summarize`` task and return its output.

    Args:
        texts: Source text, sent as the ``query`` parameter.
        language: "english" or "chinese"; selects the embedding and LLM
            endpoints. Other values send no language or endpoint parameters.
        prompt: Optional prompt template; omitted when empty.

    Returns:
        The response's ``summarize`` field. For English, any text up to and
        including the last "The Question and Answer are:" marker is removed.

    Raises:
        requests.HTTPError: the backend answered with an error status.
        requests.Timeout: no response within ``_SUMMARIZE_TIMEOUT_SECONDS``.
        KeyError: the response lacks ``summarize``.
    """
    url = _build_summarize_url(api, texts, language, prompt)
    print('url:', url)
    response = requests.get(url, timeout=_SUMMARIZE_TIMEOUT_SECONDS)
    response.raise_for_status()
    result = json.loads(response.text)
    print('result1:', result)
    return _strip_qa_preamble(result['summarize'], language)
# ---------------------------------------------------------------------------
# Gradio UI: a "Question Answering" tab wired to get_answer and a
# "Summarize" tab wired to get_summarize.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="亚马逊云科技智能问答解决方案指南")
with demo:
    gr.Markdown(
        "# <center>AWS Intelligent Q&A Solution Guide</center>"
    )
    with gr.Tabs():
        with gr.TabItem("Question Answering"):
            with gr.Row():
                with gr.Column():
                    qa_task_radio = gr.Radio(["Knowledge base Q&A", "Chat"],
                                             value="Knowledge base Q&A", label="Task")
                    query_textbox = gr.Textbox(label="Query")
                    session_id_textbox = gr.Textbox(label="Session ID")
                    qa_button = gr.Button("Submit")
                    qa_language_radio = gr.Radio(["chinese", "chinese-tc", "english"],
                                                 value="chinese", label="Language")
                    # The placeholder is only a hint: get_answer sends no prompt
                    # when this box is empty.
                    qa_prompt_textbox = gr.Textbox(
                        label="Prompt( must include {context} and {question} )",
                        placeholder=chinese_prompt,
                        lines=2,
                    )
                    qa_search_engine_radio = gr.Radio(["OpenSearch", "Kendra"],
                                                      value="OpenSearch", label="Search engine")
                    qa_index_textbox = gr.Textbox(label="OpenSearch index OR Kendra index id")
                    qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM",
                                                value=1, minimum=1, maximum=4, step=1)
                    score_type_checklist = gr.CheckboxGroup(
                        ["query_answer_score", "answer_docs_score", "docs_list_overlap_score"],
                        value=[],
                        label="Confidence score type",
                    )
                with gr.Column():
                    # Order must match the 5-tuple returned by get_answer.
                    qa_output = [
                        gr.Textbox(label="Answer"),
                        gr.Textbox(label="Confidence"),
                        gr.Textbox(label="Source"),
                        gr.Textbox(label="Url"),
                        gr.Textbox(label="Request time"),
                    ]
        with gr.TabItem("Summarize"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Input texts", lines=4)
                    summarize_button = gr.Button("Submit")
                    sm_language_radio = gr.Radio(["chinese", "english"],
                                                 value="chinese", label="Language")
                    sm_prompt_textbox = gr.Textbox(label="Prompt", lines=4,
                                                   placeholder=EN_SUMMARIZE_PROMPT_TEMPLATE)
                with gr.Column():
                    text_output = gr.Textbox()
    # Input order must match the positional parameters of each handler.
    qa_button.click(
        get_answer,
        inputs=[qa_task_radio, query_textbox, session_id_textbox, qa_language_radio,
                qa_prompt_textbox, qa_search_engine_radio, qa_index_textbox,
                qa_top_k_slider, score_type_checklist],
        outputs=qa_output,
    )
    summarize_button.click(
        get_summarize,
        inputs=[text_input, sm_language_radio, sm_prompt_textbox],
        outputs=text_output,
    )
# Start the web UI only when run as a script, so importing this module has no
# side effects. Alternative kept from the original: demo.launch(share=True).
if __name__ == "__main__":
    demo.launch()