|
import json
from datetime import datetime
from urllib.parse import quote

import gradio as gr
import requests
|
|
|
# Base URL of the API Gateway stage that fronts the backend Lambda.
invoke_url = "https://02u4taf9pf.execute-api.us-west-2.amazonaws.com/prod"

# Every request targets this endpoint; callers append query-string parameters.
api = invoke_url + '/langchain_processor_qa?query='






# Default OpenSearch index names, used when the UI "Index" field is left empty.
chinese_index = "smart_search_qa_demo_0620_cn"

english_index = "smart_search_qa_demo_0618_en_2"
|
|
|
# QA prompt template (Chinese): answer concisely and professionally from the
# retrieved context only, cite the supporting information, and refuse when the
# context is insufficient. Must contain the {question} and {context}
# placeholders (see the Prompt textbox label in the UI below).
chinese_prompt = """基于以下已知信息,简洁和专业的来回答用户的问题,并告知是依据哪些信息来进行回答的。

如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。



问题: {question}

=========

{context}

=========

答案:"""



# QA prompt template (English); same {context}/{question} contract as
# chinese_prompt.
english_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}



Question: {question}

Answer:"""
|
|
|
|
|
# FAQ-generation prompt (Chinese): generate as many Question:/Answer: pairs as
# possible from an AWS product-doc fragment ({text}).
# NOTE(review): not referenced elsewhere in this file — presumably pasted into
# the Prompt textbox by the operator; confirm before removing.
zh_prompt_template = """

如下三个反括号中是aws的产品文档片段

```

{text}

```

请基于这些文档片段自动生成尽可能多的问题以及对应答案, 尽可能详细全面, 并且遵循如下规则:

1. "aws"需要一直被包含在Question中

2. 答案部分的内容必须为上述aws的产品文档片段的内容摘要

3. 问题部分需要以"Question:"开始

4. 答案部分需要以"Answer:"开始

"""



# FAQ-generation prompt (English). Its trailing "The Question and Answer are:"
# marker is what get_summarize() strips from English answers.
# NOTE(review): also not referenced elsewhere in this file.
en_prompt_template = """

Here is one page of aws's product document

```

{text}

```

Please automatically generate FAQs based on these document fragments, with answers that should not exceed 50 words as much as possible, and follow the following rules:

1. 'aws' needs to be included in the question

2. The content of the answer section must be a summary of the content of the above document fragments



The Question and Answer are:

"""



# Question-generation prompt (English); used as the placeholder text of the
# Summarize tab's Prompt textbox.
EN_SUMMARIZE_PROMPT_TEMPLATE = """

Here is one page of aws's manual document

```

{text}

```

Please automatically generate as many questions as possible based on this manual document, and follow these rules:

1. "aws" should be contained in every question

2. questions start with "Question:"

3. answers begin with "Answer:"

"""
|
|
|
|
|
def get_answer(question, session_id, language, prompt, search_engine, index, top_k, temperature):
    """Call the remote QA endpoint and format its response for the Gradio UI.

    Builds the query string for the /langchain_processor_qa API, issues a
    blocking GET, and unpacks the JSON payload.

    Args:
        question: User query; falls back to "hello" when empty.
        session_id: Conversation id forwarded to the backend (optional).
        language: One of "english", "chinese-llm-v1", "chinese-llm-v2";
            selects the embedding + LLM endpoint pair.
        prompt: Optional prompt template override (may contain {context} and
            {question}).
        search_engine: "OpenSearch" or "Kendra".
        index: Search index name / Kendra index id; when empty with
            OpenSearch, the module-level default index for the language is used.
        top_k: Number of source passages to feed the LLM (>0 to send).
        temperature: LLM sampling temperature.

    Returns:
        Tuple of (answer, confidence, source_str, url, request_time) matching
        the five output textboxes of the QA tab.
    """
    # FIX: user-supplied values are percent-encoded before being embedded in
    # the query string. Previously a space, '&', '=' or '{...}' in the
    # question/prompt broke or injected query parameters. API Gateway decodes
    # the parameters server-side, so encoded requests are equivalent for
    # well-formed input.
    query = question if len(question) > 0 else "hello"
    url = api + quote(query)

    url += '&task=qa'

    # Select the embedding endpoint and LLM endpoint for the chosen model.
    if language == "english":
        url += '&language=english'
        url += '&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2'
        url += '&llm_embedding_name=pytorch-inference-vicuna-p3-2x'
    elif language == "chinese-llm-v1":
        url += '&language=chinese'
        url += '&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1'
        url += '&llm_embedding_name=pytorch-inference-chatglm-v1'
    elif language == "chinese-llm-v2":
        url += '&language=chinese'
        url += '&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1'
        url += '&llm_embedding_name=pytorch-inference-chatglm2-g5-2x'

    if len(session_id) > 0:
        url += '&session_id=' + quote(session_id)

    if len(prompt) > 0:
        url += '&prompt=' + quote(prompt)

    if search_engine == "OpenSearch":
        url += '&search_engine=opensearch'
        if len(index) > 0:
            url += '&index=' + quote(index)
        else:
            # No explicit index: fall back to the per-language default.
            if language.find("chinese") >= 0 and len(chinese_index) > 0:
                url += '&index=' + chinese_index
            elif language == "english" and len(english_index) > 0:
                url += '&index=' + english_index
    elif search_engine == "Kendra":
        url += '&search_engine=kendra'
        if len(index) > 0:
            url += '&kendra_index_id=' + quote(index)

    if int(top_k) > 0:
        url += '&top_k=' + str(top_k)

    url += '&temperature=' + str(temperature)
    # Ask the backend to compute every confidence metric we display.
    url += '&cal_qa_relate_score=true'
    url += '&cal_answer_relate_scores=true'
    url += '&cal_list_overlap_score=true'

    print("url:", url)

    now1 = datetime.now()
    response = requests.get(url)
    now2 = datetime.now()
    request_time = now2 - now1
    print("request takes time:", request_time)

    result = json.loads(response.text)
    print('result:', result)

    # FIX: tolerate error payloads that lack 'suggestion_answer' instead of
    # raising KeyError into the UI.
    answer = result.get('suggestion_answer', '')
    source_list = result.get('source_list', [])

    print("answer:", answer)

    # Render each retrieved source as "num:.. source:.. score:.." followed by
    # its paragraph text.
    source_str = ""
    for item in source_list:
        print('item:', item)
        header = ("num:" + str(item['id']) + " "
                  + "source:" + item['source'] + " "
                  + "score:" + str(item['score']))
        source_str += header + '\n'
        source_str += "paragraph:" + item['paragraph'] + '\n\n'

    # Collect whichever confidence metrics the backend returned.
    confidence = ""

    query_doc_scores = list(result.get('query_doc_scores', []))
    if len(query_doc_scores) > 0:
        confidence += "query_doc_scores:" + str(query_doc_scores) + '\n'

    qa_relate_score = result.get('qa_relate_score', 0)
    if float(qa_relate_score) > 0:
        confidence += "qa_relate_score:" + str(qa_relate_score) + '\n'

    answer_relate_scores = list(result.get('answer_relate_scores', []))
    if len(answer_relate_scores) > 0:
        confidence += "answer_relate_scores:" + str(answer_relate_scores) + '\n'

    list_overlap_score = result.get('list_overlap_score', 0)
    if float(list_overlap_score) > 0:
        confidence += "list_overlap_score:" + str(list_overlap_score) + '\n'

    return answer, confidence, source_str, url, request_time
|
|
|
|
|
def get_summarize(texts, language, prompt):
    """Call the remote summarize endpoint and return the generated text.

    Args:
        texts: Text to summarize (sent as the 'query' parameter).
        language: "english" or "chinese"; selects the embedding + LLM
            endpoint pair.
        prompt: Optional prompt template override.

    Returns:
        The backend's 'summarize' string; for English results, anything up to
        and including the echoed "The Question and Answer are:" marker is
        stripped.
    """
    # FIX: percent-encode user text before embedding it in the query string.
    # Previously a space, '&' or '=' in texts/prompt corrupted the request.
    url = api + quote(texts)
    url += '&task=summarize'

    # Select the embedding endpoint and LLM endpoint for the chosen language.
    if language == "english":
        url += '&language=english'
        url += '&embedding_endpoint_name=pytorch-inference-all-minilm-l6-v2'
        url += '&llm_embedding_name=pytorch-inference-vicuna-v1-1-b'
    elif language == "chinese":
        url += '&language=chinese'
        url += '&embedding_endpoint_name=huggingface-inference-text2vec-base-chinese-v1'
        url += '&llm_embedding_name=pytorch-inference-chatglm2-g5-2x'

    if len(prompt) > 0:
        url += '&prompt=' + quote(prompt)

    print('url:', url)
    response = requests.get(url)
    result = json.loads(response.text)
    print('result1:', result)

    # FIX: tolerate error payloads that lack 'summarize' instead of raising
    # KeyError into the UI.
    answer = result.get('summarize', '')

    # The English FAQ prompt ends with this marker; the LLM tends to echo it,
    # so keep only the text after the last occurrence.
    marker = 'The Question and Answer are:'
    if language == 'english' and answer.find(marker) > 0:
        answer = answer.split(marker)[-1].strip()

    return answer
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two tabs ("Question Answering", "Summarize") wired to the remote
# API through get_answer / get_summarize. Launching blocks until the server
# is stopped.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="亚马逊云科技智能问答解决方案指南")
with demo:
    gr.Markdown(
        "# <center>AWS Intelligent Q&A Solution Guide"
    )

    with gr.Tabs():
        with gr.TabItem("Question Answering"):

            with gr.Row():
                with gr.Column():
                    query_textbox = gr.Textbox(label="Query")
                    session_id_textbox = gr.Textbox(label="Session ID")
                    # FIX: button label typo "Summit" -> "Submit".
                    qa_button = gr.Button("Submit")

                    qa_language_radio = gr.Radio(["chinese-llm-v1", "chinese-llm-v2", "english"], value="chinese-llm-v1", label="Language")

                    qa_prompt_textbox = gr.Textbox(label="Prompt( must include {context} and {question} )", placeholder=chinese_prompt, lines=2)
                    qa_search_engine_radio = gr.Radio(["OpenSearch", "Kendra"], value="OpenSearch", label="Search engine")
                    qa_index_textbox = gr.Textbox(label="Index")
                    qa_top_k_slider = gr.Slider(label="Top_k of source text to LLM", value=1, minimum=1, maximum=4, step=1)

                    temperature_slider = gr.Slider(label="temperature for LLM", value=0.01, minimum=0.0, maximum=1, step=0.01)

                with gr.Column():
                    # NOTE(review): gr.outputs.* is the legacy output API —
                    # kept as-is for compatibility with the installed gradio
                    # version; migrate to gr.Textbox when upgrading.
                    qa_output = [gr.outputs.Textbox(label="Answer"), gr.outputs.Textbox(label="Confidence"), gr.outputs.Textbox(label="Source"), gr.outputs.Textbox(label="Url"), gr.outputs.Textbox(label="Request time")]

        with gr.TabItem("Summarize"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Input texts", lines=4)
                    # FIX: button label typo "Summit" -> "Submit".
                    summarize_button = gr.Button("Submit")
                    sm_language_radio = gr.Radio(["chinese", "english"], value="chinese", label="Language")

                    sm_prompt_textbox = gr.Textbox(label="Prompt", lines=4, placeholder=EN_SUMMARIZE_PROMPT_TEMPLATE)
                with gr.Column():
                    text_output = gr.Textbox()

    # Wire the buttons to the backend calls; input order must match the
    # parameter order of get_answer / get_summarize.
    qa_button.click(get_answer, inputs=[query_textbox, session_id_textbox, qa_language_radio, qa_prompt_textbox, qa_search_engine_radio, qa_index_textbox, qa_top_k_slider, temperature_slider], outputs=qa_output)
    summarize_button.click(get_summarize, inputs=[text_input, sm_language_radio, sm_prompt_textbox], outputs=text_output)

demo.launch()
|
|
|
|