# NinaAchache's picture
# correcting path
# a906c85
# raw
# history blame
# 3.8 kB
import gradio as gr
from transformers import pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
import numpy as np
import openai
# Zero-shot classifier used by is_climate_change_related() to decide whether a
# query should trigger document retrieval.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# System prompt that opens every conversation (the first element of the chat history).
system_template = {
    "role": "system",
    "content": "You have been a climate change expert for 30 years. You answer questions about climate change in an educational and concise manner.",
}

# Prebuilt FAISS index and its config, loaded from the local ./documents folder.
document_store = FAISSDocumentStore.load(
    index_path="./documents/climate_gpt.faiss",
    config_path="./documents/climate_gpt.json",
)

# Dense retriever over the document store.
# NOTE(review): the embedding model presumably has to match the one used to
# build the FAISS index — verify if either is changed.
dense = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
)
def is_climate_change_related(sentence: str) -> bool:
    """Tell whether the zero-shot classifier considers ``sentence`` climate related.

    The sentence is scored against two candidate labels. The function returns
    True when the highest-scoring label is "climate change related".
    """
    candidates = ["climate change related", "non climate change related"]
    prediction = classifier(sequences=sentence, candidate_labels=candidates)
    best_index = np.argmax(prediction["scores"])
    winning_label = prediction["labels"][best_index]
    return winning_label == "climate change related"
def make_pairs(lst):
    """Group a flat list into consecutive 2-tuples: (lst[0], lst[1]), (lst[2], lst[3]), ...

    The list is expected to have an even length. With an odd length, an
    IndexError is raised when the last, unpaired element is reached.

    >>> make_pairs(["q1", "a1", "q2", "a2"])
    [('q1', 'a1'), ('q2', 'a2')]
    """
    pairs = []
    for start in range(0, len(lst), 2):
        first = lst[start]
        second = lst[start + 1]
        pairs.append((first, second))
    return pairs
def _format_doc(d, max_chars=None):
    """Render one retrieved document as '<path> Page <page> paragraph <para>:\\n<content>'.

    When ``max_chars`` is given, the content is cut to that many characters and
    suffixed with ' [...]' (the short form shown in the Sources panel).
    """
    content = d.content if max_chars is None else f"{d.content[:max_chars]} [...]"
    return f"{d.meta['path']} Page {d.meta['page_id']} paragraph {d.meta['paragraph_id']}:\n{content}"


def gen_conv(query: str, history=None, ipcc=True, top_k=5):
    """Answer ``query`` with the OpenAI chat model, optionally grounded on retrieved extracts.

    Args:
        query: the user's question.
        history: prior chat messages (list of {"role", "content"} dicts) whose
            first element is the system prompt. Defaults to a fresh
            ``[system_template]``. The list passed in is not mutated.
        ipcc: if True and the classifier deems the query climate related,
            retrieve ``top_k`` documents and inject them as a system message.
        top_k: number of documents to retrieve when retrieval happens.

    Returns:
        (answer:list[tuple], history:list[dict], sources:str) — the
        (user, assistant) pairs for gr.Chatbot, the updated message history,
        and a short listing of the retrieved sources ("" when none).
    """
    # A None default avoids sharing one list object across calls.
    if history is None:
        history = [system_template]
    retrieve = ipcc and is_climate_change_related(query)
    sources = ""
    # Concatenation builds a new list, so the caller's history is left untouched.
    messages = history + [{"role": "user", "content": query}]
    if retrieve:
        docs = dense.retrieve(query=query, top_k=top_k)
        prompt_sources = "\n\n".join(
            ["If relevant, use those extracts from IPCC reports in your answer"]
            + [_format_doc(d) for d in docs]
        )
        messages.append({"role": "system", "content": prompt_sources})
    answer = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.2,
    )["choices"][0]["message"]["content"]
    if retrieve:
        # Drop this turn's sources message so it is not carried into the
        # history (and resent) on later turns.
        messages.pop()
        # Report the number actually returned, which may be fewer than top_k.
        answer = f"(top {len(docs)} documents retrieved) " + answer
        sources = "\n\n".join(_format_doc(d, max_chars=100) for d in docs)
    messages.append({"role": "assistant", "content": answer})
    # Skip the leading system prompt; the rest alternates user/assistant.
    gradio_format = make_pairs([m["content"] for m in messages[1:]])
    return gradio_format, messages, sources
def connect(text):
    """Set the OpenAI API key used by gen_conv and return a status message.

    Surrounding whitespace (common when pasting a key) is stripped. An empty
    input leaves any previously set key unchanged.

    NOTE(review): this assigns the module-global ``openai.api_key``, so in a
    shared app every visitor uses whichever key was entered last. A per-session
    key (e.g. held in gr.State and passed to gen_conv) would avoid that.
    """
    key = (text or "").strip()
    if not key:
        return "Please enter an OpenAI API key"
    openai.api_key = key
    return "You're all set"
# Gradio UI: an API-key connection row, then the chat panel and a Sources panel.
with gr.Blocks(title="Eki IPCC Explorer") as demo:
    with gr.Row():
        with gr.Column():
            # type="password" masks the secret key instead of echoing it on screen.
            api_key = gr.Textbox(label="Open AI api key", type="password")
            connect_btn = gr.Button(value="Connect")
        with gr.Column():
            result = gr.Textbox(label="Connection")
    connect_btn.click(connect, inputs=api_key, outputs=result, api_name="Connection")
    gr.Markdown(
        """
# Ask me anything, I'm an IPCC report
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            # Per-session chat history, seeded with the system prompt.
            state = gr.State([system_template])
            with gr.Row():
                ask = gr.Textbox(
                    show_label=False, placeholder="Enter text and press enter"
                ).style(container=False)
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### Sources")
            sources_textbox = gr.Textbox(
                interactive=False, show_label=False, max_lines=50
            )
    # Submitting a question updates the chat display, the history state and the sources panel.
    ask.submit(
        fn=gen_conv, inputs=[ask, state], outputs=[chatbot, state, sources_textbox]
    )

# Only start the server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=True)