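"""Streamlit demo for ReSRer (Retriever-Summarizer-Reader).

Pipeline: a DPR encoder embeds the question, Milvus retrieves the top
Wikipedia passages, a summarizer condenses them, and a reader extracts
the final short answer.
"""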
import os

import streamlit as st
import torch
from pymilvus import MilvusClient

from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC"
st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
<h5>Ask a short-answer question that can be answered from Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown(
    'This demo searches through 21,000,000 Wikipedia passages in real time under the hood.')
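
# Load the encoder, summarizer, and reader once; st.cache_resource keeps
# them in memory across Streamlit reruns.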
@st.cache_resource
def load_models():
    models = {}
    models['encoder'] = get_dpr_encoder()
    models['summarizer'] = get_summarizer()
    models['reader'] = get_reader()
    return models
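
# Connect to the Milvus vector database; credentials come from the
# MILVUS_PW and MILVUS_HOST environment variables.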
@st.cache_resource
def load_client():
    client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                          uri=f"http://{os.environ['MILVUS_HOST']}:19530",
                          db_name='psgs_w100')
    return client

client = load_client()
models = load_models()
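
# CSS to pin the Streamlit status widget to the center of the page and
# hide its stop button while the pipeline runs.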
styl = """
<style>
  .StatusWidget-enter-done {
    position: fixed;
    left: 50%;
    top: 50%;
    transform: translate(-50%, -50%);
  }
  .StatusWidget-enter-done button {
    display: none;
  }
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
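
# Free-text question input plus three one-click example questions.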
question = st.text_input("Question", INITIAL)
col1, col2, col3 = st.columns(3)
if col1.button("What is the capital of South Korea"):
    question = "What is the capital of South Korea"
if col2.button("What is the most famous building in Paris"):
    question = "What is the most famous building in Paris"
if col3.button("Who is the actor of Harry Potter"):
    question = "Who is the actor of Harry Potter"
@torch.inference_mode()
def main(question: str):
    if question in st.session_state:
        # Reuse results cached in the session for repeated questions.
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")
        # Embedding: encode the question into a dense DPR vector.
        question_vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
        # Retriever: fetch the top-10 passages from Milvus.
        results = client.search(collection_name='dpr_nq', data=[query_vector],
                                limit=10, output_fields=['title', 'text'])
        texts = [result['entity']['text'] for result in results[0]]
        ctx = '\n'.join(texts)
        # Summarizer: condense the retrieved passages.
        [summary] = summarize_text(models['summarizer'][0],
                                   models['summarizer'][1], [ctx])
        # Reader: extract the short answer from the summary.
        answers = ask_reader(models['reader'][0],
                             models['reader'][1], [question], [summary])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")
        st.session_state[question] = (ctx, summary, answer)

    # Render the answer, the summarized context, and the original passages.
    st.write(f"### Answer: {answer}")
    st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>",
        unsafe_allow_html=True)
    st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
    st.markdown(ctx)


if question:
    main(question)