|
import gradio as gr |
|
import openai |
|
import re |
|
import threading |
|
import json |
|
import os |
|
from collections import Counter |
|
from llm_utils import * |
|
from utils import * |
|
from retrieval_utils import * |
|
|
|
# OpenAI client configuration is taken from the lower-case `api_key` and
# `api_base` environment variables. Override only when a value is present:
# assigning None unconditionally would discard whatever the openai package
# configured on its own (openai<1.0 reads OPENAI_API_KEY / OPENAI_API_BASE and
# defaults api_base to the public endpoint), leaving the client unusable.
_env_api_key = os.getenv("api_key")
if _env_api_key:
    openai.api_key = _env_api_key
_env_api_base = os.getenv("api_base")
if _env_api_base:
    openai.api_base = _env_api_base

# Zero-shot chain-of-thought trigger placed after "A: " in prompts.
COT_PROMPT = "Let's think step by step."

# Phrase that introduces the final answer in the demonstration text.
DIRECT_ANS_PROMPT = "The answer is"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide lock serialising the feedback-file appends done in
# record_feedback. (A `global` statement at module scope is a no-op, so none is needed.)
lock = threading.Lock()
|
|
|
def answer_extraction_prompt(datatype):
    """Return the answer-extraction cue appended after a rationale.

    Args:
        datatype: task type label (e.g. "arithmetic", "commonsense-mc").

    Returns:
        A string starting with "\\nTherefore, ..." that steers the model toward
        the answer format for that task type. Unknown types get the generic cue.
    """
    yes_no_cue = "\nTherefore, the answer (Yes or No) is"
    cue_table = (
        ("commonsense-mc", "\nTherefore, among A through E, the answer is"),
        ("commonsense-verify", yes_no_cue),
        ("arithmetic", "\nTherefore, the answer (arabic numerals) is"),
        ("symbolic-letter", "\nTherefore, the answer is"),
        ("symbolic-coin", yes_no_cue),
    )
    # Checked in order with == so the comparison semantics match an if/elif chain.
    for known_type, cue in cue_table:
        if datatype == known_type:
            return cue
    return "\nTherefore, the answer is"
|
|
|
|
|
def zero_shot(datatype, question, engine):
    """Answer `question` directly (no rationale) and clean the reply.

    The extraction cue for `datatype` loses its "\\nTherefore, " lead-in and is
    capitalised, e.g. "The answer (arabic numerals) is". It is then placed
    right after "A: ". The reply from `decoder_for_gpt3` (max 32 tokens) is
    post-processed by `answer_cleansing_zero_shot` (both defined elsewhere).

    Returns:
        The cleaned answer, or "VOID" when cleaning produced an empty string.
    """
    cue = answer_extraction_prompt(datatype).replace("\nTherefore, ", "")
    cue = cue[0].upper() + cue[1:]
    prompt = "Q: " + question + "\n" + "A: " + cue
    raw_reply = decoder_for_gpt3(prompt, max_length=32, engine=engine)
    cleaned = answer_cleansing_zero_shot(datatype, raw_reply)
    return "VOID" if cleaned == "" else cleaned
|
|
|
|
|
|
|
def highlight_knowledge(entities, retrieved_knowledge):
    """Highlight every case-insensitive occurrence of the given entities.

    Args:
        entities: iterable of entity strings. They are matched literally
            (regex metacharacters are escaped). Empty or whitespace-only
            entries are ignored.
        retrieved_knowledge: the text to annotate.

    Returns:
        `retrieved_knowledge` with each match replaced by
        ``<span style="background-color: lightcoral"> **<match>** </span>``.
        The matched text keeps its original casing.
    """
    terms = {ent for ent in entities if ent and ent.strip()}
    if not terms:
        return retrieved_knowledge

    # One alternation, longest entity first, applied in a single pass over the
    # original text. Overlapping entities resolve to the longest one, and
    # markup inserted for one entity is never re-scanned by another
    # (e.g. "co" would otherwise match inside "lightcoral").
    pattern = re.compile(
        "|".join(re.escape(term) for term in sorted(terms, key=lambda t: (-len(t), t))),
        re.IGNORECASE,
    )
    return pattern.sub(
        lambda m: '<span style="background-color: lightcoral"> **' + m.group(0) + '** </span>',
        retrieved_knowledge,
    )
|
|
|
def zero_cot_consi(question, engine):
    """Run a zero-shot chain-of-thought prompt through the consistency decoder.

    Builds "Q: <question>\\nA: Let's think step by step." and returns whatever
    `decoder_for_gpt3_consistency` (defined elsewhere) returns. That is
    presumably a list of sampled rationales; confirm against llm_utils.
    """
    prompt = "".join(("Q: ", question, "\n", "A: ", COT_PROMPT))
    return decoder_for_gpt3_consistency(prompt, max_length=256, engine=engine)
|
|
|
def auto_cot_consi(question, demo_text, engine):
    """Run a few-shot chain-of-thought prompt through the consistency decoder.

    `demo_text` (built by self_construction) is prepended to
    "Q: <question>\\nA: Let's think step by step.". The result of
    `decoder_for_gpt3_consistency` (defined elsewhere) is returned unchanged.
    cot_revision iterates over it as a collection of rationales.
    """
    prompt = "".join((demo_text, "Q: ", question, "\n", "A: ", COT_PROMPT))
    return decoder_for_gpt3_consistency(prompt, max_length=256, engine=engine)
|
|
|
|
|
def cot_revision(datatype, question, ori_cots, knowledge, engine):
    """Revise each rationale with the given knowledge and extract its answer.

    For every rationale in `ori_cots`, the model is asked to rewrite it using
    `knowledge` (max 256 tokens, temperature 0.7). The revised rationale is
    then followed by the task's extraction cue to obtain an answer (max 32
    tokens, temperature 0.7), which `answer_cleansing_zero_shot` cleans.

    Returns:
        (revised_rationales, answers): two lists aligned with `ori_cots`.
    """
    extraction_cue = answer_extraction_prompt(datatype)
    shared_header = "Question: [ " + question + "]\n" + "Knowledge: [ " + knowledge + "]\n"
    instruction = "With Knowledge given, output the revised rationale for Question in a precise and certain style by thinking step by step: "

    revised_rationales, answers = [], []
    for rationale in ori_cots:
        revise_prompt = shared_header + "Original rationale: [ " + rationale + "]\n" + instruction
        revised = decoder_for_gpt3(revise_prompt, max_length=256, temperature=0.7, engine=engine).strip()
        revised_rationales.append(revised)

        extraction_prompt = "Q: " + question + "\n" + "A: " + revised + extraction_cue
        raw_answer = decoder_for_gpt3(extraction_prompt, max_length=32, temperature=0.7, engine=engine)
        answers.append(answer_cleansing_zero_shot(datatype, raw_answer))

    return revised_rationales, answers
|
|
|
|
|
def consistency(arr):
    """Majority-vote over a list of answers.

    Args:
        arr: sized sequence of hashable answers.

    Returns:
        (most_frequent_item, ans_dict). `ans_dict` maps each distinct answer to
        its relative frequency (count / len(arr)), ordered from most to least
        frequent. Ties keep first-seen order, as `Counter.most_common` does.
        An empty input returns ("", {}) instead of raising IndexError.
    """
    total = len(arr)
    if total == 0:
        return "", {}
    ranked = Counter(arr).most_common()
    ans_dict = {ans: times / total for ans, times in ranked}
    return ranked[0][0], ans_dict
|
|
|
|
|
|
|
def record_feedback(single_data, feedback, store_flag):
    """Optionally persist one feedback record, then hide the feedback buttons.

    Args:
        single_data: dict describing this run (question, answers, knowledge).
            Its 'datatype' value names the output file
            ``./data_pool/<datatype>_feedback``. When storing, its 'feedback'
            key is set in place.
        feedback: label to record ('agree', 'disagree' or 'uncertain' from
            this file's callers).
        store_flag: user consent. The record is only written when truthy.

    Returns:
        Four gradio updates: hide the three feedback buttons and show a
        thank-you message.
    """
    print("Logging feedback...")

    datatype = single_data['datatype']
    # datatype comes from a UI dropdown, i.e. it is client-supplied. Keep only
    # its final path component so it cannot point the write outside ./data_pool.
    dataname = os.path.basename(str(datatype))
    data_dir = os.path.join('./data_pool', '{dataname}_feedback'.format(dataname=dataname))

    # `with` releases the lock even if the write raises. A bare
    # acquire()/release() pair would keep it held after any I/O error and
    # deadlock every later feedback submission.
    with lock:
        if store_flag:
            single_data.update({'feedback': feedback})
            os.makedirs('./data_pool', exist_ok=True)
            with open(data_dir, "a", encoding="utf-8") as f:
                f.write(json.dumps(single_data) + "\n")

    print("Logging finished...")

    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
        gr.update(value="😃 Thank you for your valuable feedback!")
|
|
|
|
|
def _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans):
    """Build the dict passed to record_feedback; key order is the JSON field order."""
    return {
        'question': input_question,
        'datatype': datatype,
        'zshot_ans': zshot_ans,
        'adapter_ans': our_ans,
        'self_know': self_know,
        'kb_know': kb_know,
        'refine_know': refine_know,
        'cor_ans': cor_ans,
        'feedback': "",
    }


def record_feedback_agree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Record an 'agree' vote for the current AuRoRA / zero-shot answers."""
    record = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(record, 'agree', store_flag)


def record_feedback_disagree(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Record a 'disagree' vote for the current AuRoRA / zero-shot answers."""
    record = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(record, "disagree", store_flag)


def record_feedback_uncertain(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans, store_flag):
    """Record an 'uncertain' vote for the current AuRoRA / zero-shot answers."""
    record = _feedback_record(input_question, datatype, our_ans, zshot_ans, self_know, kb_know, refine_know, cor_ans)
    return record_feedback(record, 'uncertain', store_flag)
|
|
|
# NOTE: a 10-output reset() used to be defined here. It was dead code: the
# module defines reset() again further down (12 outputs, matching the
# reset_btn wiring), and that later definition replaced this one at import time.
|
|
|
|
|
def identify_type(question, engine):
    """Ask the LLM which task type `question` belongs to.

    The few-shot prompt in ./demos/type is extended with the question and the
    list of allowed types. It is decoded with `decoder_for_gpt3` (32 tokens,
    temperature=0). The stripped, lower-cased reply is post-processed by
    `type_cleasing` (both helpers are defined elsewhere).
    """
    # Decode explicitly as UTF-8, as self_construction does for its demo files,
    # so the prompt file reads the same regardless of the host locale.
    with open('./demos/type', 'r', encoding="utf-8") as f:
        typedemo = f.read()

    typedemo += "Question: " + question + "\nOutput the Type, choosing from <'arithmetic','commonsense-mc','commonsense-verify','symbolic-coin', 'symbolic-letter'>: "

    response = decoder_for_gpt3(typedemo, 32, temperature=0, engine=engine)
    response = response.strip().lower()
    return type_cleasing(response)
|
|
|
def load_examples(datatype):
    """Return a gradio update that swaps in the example set for `datatype`.

    NOTE(review): `EXAMPLES` is not defined in this file's visible code.
    It presumably comes from one of the star imports. Verify before relying on it.
    """
    chosen_examples = EXAMPLES[datatype]
    return gr.update(examples=chosen_examples)
|
|
|
|
|
def self_construction(datatype):
    """Load the stored demonstrations for `datatype` and render them.

    Returns 11 gradio updates:
    - sub-title;
    - clustering figure;
    - rendered demo markdown;
    - raw demo text (hidden state consumed by self_revision);
    - seven hidden result/feedback widgets.

    For an unrecognised datatype the figure is hidden, a zero-shot notice is
    shown and the demo text is cleared.
    """
    sources = (
        ("arithmetic", './figs/multiarith.png', './demos/multiarith'),
        ("commonsense-mc", './figs/commonsensqa.png', './demos/commonsensqa'),
        ("commonsense-verify", './figs/strategyqa.png', './demos/strategyqa'),
        ("symbolic-coin", './figs/coin_flip.png', './demos/coin_flip'),
        ("symbolic-letter", './figs/last_letters.png', './demos/last_letters'),
    )
    for known_type, figure_path, demos_path in sources:
        if datatype == known_type:
            fig_adr, demo_path = figure_path, demos_path
            break
    else:
        hidden_tail = tuple(gr.update(visible=False) for _ in range(7))
        return (gr.update(value="## 🔭 Self construction..."), gr.update(visible=False),
                gr.update(visible=True, value="UNDEFINED Scenario! We just employ the zero-shot setting."),
                gr.update(value="")) + hidden_tail

    with open(demo_path, encoding="utf-8") as f:
        demos = json.load(f)["demo"]

    q_tag = '<span style="background-color: #E0A182">' + "Q: " + '</span>'
    a_tag = '<span style="background-color: #DD97AF">' + "A: " + '</span>'
    text_parts, md_parts = [], []
    for item in demos:
        question, rationale, answer = item["question"], item["rationale"], item["pred_ans"]
        text_parts.append(question + " " + rationale + " " + DIRECT_ANS_PROMPT + " " + answer + ".\n\n")
        # [3:-3] trims fixed-width wrappers off the stored question (presumably
        # a leading "Q: " and a trailing "\nA:" — confirm against the demo files).
        md_parts.append(q_tag + question[3:-3] + "<br>" + a_tag + rationale + " "
                        + DIRECT_ANS_PROMPT + " " + answer + ".\n\n")
    demo_text = "".join(text_parts)
    demo_md = "".join(md_parts)

    hidden_tail = tuple(gr.update(visible=False) for _ in range(7))
    return (gr.update(value="## 🔭 Self construction..."),
            gr.update(visible=True, label="Visualization of clustering", value=fig_adr),
            gr.update(visible=True, value=demo_md),
            gr.update(value=demo_text)) + hidden_tail
|
|
|
def self_retrieval(input_question, engine):
    """Retrieve entities plus LLM and KB knowledge for the question and render them.

    Returns 13 gradio updates:
    - sub-title;
    - stage figure;
    - rendered markdown;
    - three hidden-state values (the ", "-joined entities, LLM knowledge, KB knowledge);
    - seven hidden result/feedback widgets.
    """
    entities, llm_knowledge, kb_knowledge = retrieve_for_question(input_question, engine)

    joined_entities = ", ".join(entities)
    sections = (
        ("ENTITIES", joined_entities),
        ("LLM-KNOWLEDGE", highlight_knowledge(entities, llm_knowledge)),
        ("KB-KNOWLEDGE", highlight_knowledge(entities, kb_knowledge)),
    )
    retr_md = "".join("### " + title + ":" + "<br>" + "> " + body + "\n\n" for title, body in sections)

    hidden_tail = tuple(gr.update(visible=False) for _ in range(7))
    return (gr.update(value="## 📚 Self retrieval..."),
            gr.update(visible=True, label="", value='./figs/self-retrieval.png'),
            gr.update(value=retr_md),
            gr.update(value=joined_entities),
            gr.update(value=llm_knowledge),
            gr.update(value=kb_knowledge)) + hidden_tail
|
|
|
def _split_entities(entities):
    """Parse the ", "-joined entity string (optionally wrapped in <p>...</p>) into non-empty entities."""
    text = entities.strip()
    # Remove the literal <p>/</p> wrapper. str.strip('<p>') would instead strip
    # any of '<', 'p', '>' from both ends and mangle real entities
    # (e.g. "pineapple, sheep" -> "ineapple, shee").
    if text.startswith("<p>"):
        text = text[len("<p>"):]
    if text.endswith("</p>"):
        text = text[:-len("</p>")]
    return [ent.strip() for ent in text.split(", ") if ent.strip()]


def self_refinement(input_question, entities, self_retrieve_knowledge, kb_retrieve_knowledge, engine):
    """Refine the retrieved knowledge and render every knowledge source with entity highlights.

    Args:
        input_question: the user's question.
        entities: the ", "-joined entity string that self_retrieval stored in a
            hidden textbox.
        self_retrieve_knowledge: LLM-generated knowledge from self_retrieval.
        kb_retrieve_knowledge: knowledge-base knowledge from self_retrieval.
        engine: model name passed to `refine_for_question` (defined elsewhere).

    Returns:
        11 gradio updates: sub-title, stage figure, rendered markdown, the
        refined knowledge (hidden state for self_revision), and seven hidden
        result/feedback widgets.
    """
    refine_knowledge = refine_for_question(input_question, engine, self_retrieve_knowledge, kb_retrieve_knowledge)

    # The ENTITIES section shows the string as received; highlighting uses the parsed list.
    retr_md = "### ENTITIES:" + "<br>" + "> " + entities + "\n\n"
    entity_list = _split_entities(entities)
    retr_md += "### LLM-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entity_list, self_retrieve_knowledge) + "\n\n"
    retr_md += "### KB-KNOWLEDGE:" + "<br>" + "> " + highlight_knowledge(entity_list, kb_retrieve_knowledge) + "\n\n"

    refine_md = retr_md + "### REFINED-KNOWLEDGE:" + "<br>" + "> "
    refine_md += highlight_knowledge(entity_list, refine_knowledge)

    return gr.update(value="## 🪄 Self refinement..."), gr.update(visible=True, label="", value='./figs/self-refinement.png'), \
        gr.update(value=refine_md), gr.update(value=refine_knowledge), \
        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
|
|
def self_revision(input_question, datatype, demo_text, refined_knowledge, engine):
    """Sample few-shot rationales, revise them with the refined knowledge, and render the result.

    Returns 11 gradio updates:
    - sub-title;
    - stage figure;
    - markdown listing the revised rationales;
    - the ", "-joined revised answers (hidden state for self_consistency);
    - seven hidden result/feedback widgets.
    """
    # Debug output of the inputs to this stage.
    print(demo_text)
    print(refined_knowledge)

    sampled_cots = auto_cot_consi(input_question, demo_text, engine)
    revised_cots, revised_answers = cot_revision(datatype, input_question, sampled_cots, refined_knowledge, engine)

    rationale_md = "### Revised Rationales:" + "\n\n" + "".join("> " + cot + "\n\n" for cot in revised_cots)
    joined_answers = ", ".join(revised_answers)

    hidden_tail = tuple(gr.update(visible=False) for _ in range(7))
    return (gr.update(value="## 🔧 Self revision..."),
            gr.update(visible=True, label="", value='./figs/self-revision.png'),
            gr.update(value=rationale_md),
            gr.update(value=joined_answers)) + hidden_tail
|
|
|
def self_consistency(cor_ans, datatype, question, engine):
    """Majority-vote the revised answers and show them next to a plain zero-shot answer.

    Args:
        cor_ans: the ", "-joined answers produced by self_revision.

    Returns:
        10 gradio updates:
        - sub-title;
        - stage figure;
        - cleared markdown;
        - vote distribution label;
        - AuRoRA answer;
        - zero-shot answer;
        - the three feedback buttons (shown);
        - the feedback prompt.
    """
    votes = cor_ans.strip().split(", ")
    majority_answer, vote_distribution = consistency(votes)
    baseline_answer = zero_shot(datatype, question, engine)

    header = gr.update(value="## 🗳 Self consistency...")
    figure = gr.update(visible=True, label="", value='./figs/self-consistency.png')
    cleared_md = gr.update(value="")
    distribution = gr.update(value=vote_distribution, visible=True)
    ours = gr.update(visible=True, value=majority_answer)
    zero_shot_view = gr.update(visible=True, value=baseline_answer)
    feedback_buttons = tuple(gr.update(visible=True) for _ in range(3))
    feedback_prompt = gr.update(visible=True, value='We would appreciate it very much if you could share your feedback. ')

    return (header, figure, cleared_md, distribution, ours, zero_shot_view) + feedback_buttons + (feedback_prompt,)
|
|
|
|
|
def reset():
    """Return the 12 updates wired to the RESET button.

    Text-bearing components are cleared; figure, label, answers and feedback
    buttons are hidden.
    """
    def cleared():
        return gr.update(value="")

    def hidden():
        return gr.update(visible=False)

    # Order matches reset_btn's outputs: input_question, datatype, sub_title,
    # plot, md, label, ans_ours, ans_zeroshot, feedback_agree,
    # feedback_disagree, feedback_uncertain, feedback_ack.
    return (
        cleared(), cleared(), cleared(),
        hidden(), cleared(), hidden(), hidden(), hidden(),
        hidden(), hidden(), hidden(), cleared(),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI for the AuRoRA pipeline:
# - one button per stage (identification, construction, retrieval,
#   refinement, revision, consistency);
# - a shared output area (sub-title, figure, markdown, answers);
# - feedback buttons.
# The css rule targets components whose elem_id is "process_btn".
with gr.Blocks(theme="bethecloud/storj_theme", css="#process_btn {background-color:#8BA3C5}") as demo:
    gr.Markdown("# AuRoRA: Augmented Reasoning and Refining with Task-Adaptive Chain-of-Thought Prompting")

    # Question input, storage consent and engine selection.
    with gr.Row():
        with gr.Column(scale=4):
            input_question = gr.Textbox(placeholder="Input question here, or select an example from below.", label="Input Question",lines=2)
            # Passed to the record_feedback_* handlers; the record is only written when checked.
            store_flag = gr.Checkbox(label="Store data",value=True, interactive=True, info="If you agree to store data for research and development use:")
            # NOTE(review): `single_data` is not used by any event below; looks unused.
            single_data = gr.JSON(visible=False)
        with gr.Column(scale=3):
            engine = gr.Dropdown(choices=['gpt-3.5-turbo','gpt-3.5-turbo-instruct', 'text-davinci-002', 'text-curie-001', 'text-babbage-001', 'text-ada-001'],
                                 label="Engine", value="gpt-3.5-turbo-instruct", interactive=True, info="Choose the engine and have a try!")
            reset_btn = gr.Button(value='RESET')

    # Task-type detection; the user can override the detected type in the dropdown.
    with gr.Row():
        with gr.Column(scale=1):
            type_btn = gr.Button(value="Self-identification", variant='primary', scale=1, elem_id="process_btn")
        with gr.Column(scale=3):
            datatype = gr.Dropdown(choices=['arithmetic','commonsense-mc','commonsense-verify','symbolic-letter','symbolic-coin','UNDEFINED'],
                                   label="Input Type", info="If you disagree with our output, please select manually.", scale=3)

    # Hidden textboxes that carry intermediate results from one stage's
    # handler (outputs) to the next stage's handler (inputs).
    demo_text = gr.Textbox(visible=False)
    entities = gr.Textbox(visible=False)
    self_know = gr.Textbox(visible=False)
    kb_know = gr.Textbox(visible=False)
    refine_know = gr.Textbox(visible=False)
    cor_ans = gr.Textbox(visible=False)

    # Stage buttons. NOTE(review): several components share elem_id
    # "process_btn" (duplicate HTML ids); the CSS rule is meant to style them all.
    with gr.Row():
        const_btn = gr.Button(value='Self-construction', variant='primary', elem_id="process_btn")
        retr_btn = gr.Button(value='Self-retrieval', variant='primary', elem_id="process_btn")
        refine_btn = gr.Button(value='Self-refinement', variant='primary', elem_id="process_btn")
        revis_btn = gr.Button(value='Self-revision', variant='primary', elem_id="process_btn")
        consis_btn = gr.Button(value='Self-consistency', variant='primary', elem_id="process_btn")

    # Shared output area, rewritten by every stage handler.
    sub_title = gr.Markdown()
    with gr.Row():
        with gr.Column(scale=2):
            plot = gr.Image(label="Visualization of clustering", visible=False)
        with gr.Column(scale=3):
            md = gr.Markdown()
            label = gr.Label(visible=False, label="Consistency Predictions")
            ans_ours = gr.Textbox(label="AuRoRA Answer",visible=False)
            ans_zeroshot = gr.Textbox(label="Zero-shot Answer", visible=False)
            # Feedback buttons; self_consistency makes them visible.
            with gr.Row():
                feedback_agree = gr.Button(value='😊 Agree', variant='secondary', visible=False)
                feedback_disagree = gr.Button(value='🙁 Disagree', variant='secondary', visible=False)
                feedback_uncertain = gr.Button(value='🤔 Uncertain', variant='secondary', visible=False)
            feedback_ack = gr.Markdown(value='', visible=True, interactive=False)

    # Event wiring: each handler returns one update per component in its
    # `outputs` list, in the same order.
    type_btn.click(identify_type, inputs=[input_question, engine], outputs=[datatype])
    const_btn.click(self_construction, inputs=[datatype], outputs=[sub_title, plot, md, demo_text, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    retr_btn.click(self_retrieval, inputs=[input_question, engine], outputs=[sub_title, plot, md, entities, self_know, kb_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    refine_btn.click(self_refinement, inputs=[input_question, entities, self_know, kb_know, engine], outputs=[sub_title, plot, md, refine_know, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    revis_btn.click(self_revision, inputs=[input_question, datatype, demo_text, refine_know, engine], outputs=[sub_title, plot, md, cor_ans, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    consis_btn.click(self_consistency, inputs=[cor_ans, datatype, input_question, engine], outputs=[sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    reset_btn.click(reset, inputs=[], outputs=[input_question, datatype, sub_title, plot, md, label, ans_ours, ans_zeroshot, feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])

    # Feedback handlers receive the displayed answers (ans_ours -> our_ans,
    # ans_zeroshot -> zshot_ans) plus the hidden pipeline state.
    feedback_agree.click(record_feedback_agree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    feedback_disagree.click(record_feedback_disagree, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])
    feedback_uncertain.click(record_feedback_uncertain, inputs=[input_question, datatype, ans_ours, ans_zeroshot, self_know, kb_know, refine_know, cor_ans ,store_flag], outputs=[feedback_agree, feedback_disagree, feedback_uncertain, feedback_ack])

# Starts the server as a module-level side effect (runs on import as well as on execution).
demo.launch()
|
|
|
|
|
|
|
|