gaia / app.py
ola13's picture
Migrate from Gradio to Streamlit (#1)
f077846
raw
history blame
8.71 kB
import html
import http.client as http_client
import json
import logging
import os
import pprint
import re
import string

import requests
import streamlit as st
import streamlit.components.v1 as components
pp = pprint.PrettyPrinter(indent=2)

# Page title and wide layout; this must be the first Streamlit call of the script.
st.set_page_config(page_title="Gaia Search", layout="wide")

# (Re)write a project-level Streamlit config in the working directory that
# selects the light theme. This runs on every script execution.
# NOTE(review): Streamlit generally reads config.toml at startup, so this may
# only take effect on the next launch — confirm.
_theme_config_toml = '[theme]\nbase="light"'
os.makedirs(os.path.join(os.getcwd(), ".streamlit"), exist_ok=True)
with open(os.path.join(os.getcwd(), ".streamlit/config.toml"), "w") as config_file:
    config_file.write(_theme_config_toml)
# Sidebar language label -> language code sent to the search backend.
# Order matters: it is the order the labels were offered in the UI.
# "detect_language" makes scisearch omit the "lang" field from its request.
_LANGUAGE_CODE_PAIRS = (
    ("Arabic", "ar"),
    ("Catalan", "ca"),
    ("Code", "code"),
    ("English", "en"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("Indonesian", "id"),
    ("Indic", "indic"),
    ("Niger-Congo", "nigercongo"),
    ("Portuguese", "pt"),
    ("Vietnamese", "vi"),
    ("Chinese", "zh"),
    ("Detect Language", "detect_language"),
    ("All", "all"),
)
LANG_MAPPING = dict(_LANGUAGE_CODE_PAIRS)
# Sidebar title block: a large centred "Gaia Search" heading plus a tagline.
_SIDEBAR_TITLE_HTML = """
<style>
.aligncenter {
text-align: center;
font-weight: bold;
font-size: 50px;
}
</style>
<p class="aligncenter">Gaia Search 🌖🌏</p>
<p style="text-align: center;"> A search engine for the LAION large scale image caption corpora</p>
"""

# Sidebar link row (GitHub / Project Report) and an "Open in Colab" badge.
# TODO(review): every href below is empty — fill in the real URLs.
_SIDEBAR_LINKS_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="" >GitHub</a> | <a href="" >Project Report</a>
</p>
<p class="aligncenter">
<a href="" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
</a>
</p>
"""

# Render both snippets in order as raw HTML in the sidebar.
for _sidebar_html in (_SIDEBAR_TITLE_HTML, _SIDEBAR_LINKS_HTML):
    st.sidebar.markdown(_sidebar_html, unsafe_allow_html=True)
# Sidebar inputs: free-text query, language choice and result-count limit.
query = st.sidebar.text_input(label='Search query', value='')

# Offer exactly the labels LANG_MAPPING knows, in its order, so the later
# lookup LANG_MAPPING[language] can never miss. Previously this list was
# duplicated by hand and could drift out of sync. "English" is the default.
_language_options = tuple(LANG_MAPPING)
language = st.sidebar.selectbox(
    'Language',
    _language_options,
    index=_language_options.index('English'),
)

max_results = st.sidebar.slider(
    "Maximum Number of Results",
    min_value=1,
    max_value=100,
    step=1,
    value=10,
    help="Maximum Number of Documents to return",
)
# Footer fixed to the bottom of the viewport, crediting HuggingFace and Pyserini.
footer = """<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
<div class="footer">
<p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
</div>
"""
st.sidebar.markdown(
    footer,
    unsafe_allow_html=True,
)
def scisearch(query, language, num_results=10):
    """Send *query* to the search backend and return its hits.

    The backend URL is read from the ``address`` environment variable and the
    request is a JSON POST with a 60 s timeout.

    Args:
        query: Free-text search query; surrounding whitespace is stripped.
            ``None`` is treated like an empty query.
        language: Backend language code (a value of ``LANG_MAPPING``).
            ``"detect_language"`` omits the ``lang`` field from the request.
        num_results: Maximum number of hits to request (sent as ``k``).

    Returns:
        * ``None`` when the query is empty.
        * An HTML snippet (``str``) when the backend reports that the detected
          language is unsupported, or when the request/response handling raises.
        * Otherwise a ``(results, highlight_terms)`` tuple taken from the
          backend's JSON payload.
    """
    # Strip outside the try: the old code stripped first and checked for None
    # afterwards, so a None query crashed instead of returning early.
    query = (query or "").strip()
    if not query:
        return None
    try:
        post_data = {"query": query, "k": num_results}
        if language != "detect_language":
            post_data["lang"] = language
        output = requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        payload = json.loads(output.text)
        if "err" in payload:
            if payload["err"]["type"] == "unsupported_lang":
                detected_lang = payload["err"]["meta"]["detected_lang"]
                return f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Detected language <b>{detected_lang}</b> is not supported.<br>
Please choose a language from the dropdown or type another query.
</p><br><hr><br>"""
            # Other error types fall through; the missing "results" key then
            # raises and is reported by the handler below.
        results = payload["results"]
        highlight_terms = payload["highlight_terms"]
    except Exception as e:
        # UI boundary: log the full traceback and show a message instead of
        # crashing. Previously this built the message but never returned it,
        # then failed on the unset `results`.
        logging.exception("Search request failed")
        return f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Raised {type(e).__name__}</p>
<p style='font-size:14px; font-family: Arial; '>
Check if a relevant discussion already exists in the Community tab. If not, please open a discussion.
</p>
"""
    return results, highlight_terms
# Tags of redacted-PII placeholders expected in document text. Each occurs as
# PII_PREFIX + tag (e.g. "PI:EMAIL") and is rendered as a "REDACTED <tag>" badge.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every ``PI:<TAG>`` placeholder in *text* with an HTML badge.

    Args:
        text: Text possibly containing placeholders such as ``PI:EMAIL``;
            matching is case-sensitive.

    Returns:
        *text* with each placeholder replaced by a highlighted
        ``REDACTED <TAG>`` badge.
    """
    for tag in PII_TAGS:
        text = text.replace(
            PII_PREFIX + tag,
            """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>""".format(tag),
        )
    return text


def _highlight_pattern(highlight_terms):
    """Build one regex matching a PII placeholder or any whole-word search term."""
    # Longest alternatives first so a longer placeholder/term wins over a prefix.
    by_length = lambda s: (-len(s), s)
    markers = sorted((PII_PREFIX + tag for tag in PII_TAGS), key=by_length)
    alternatives = ["(?P<pii>" + "|".join(map(re.escape, markers)) + ")"]
    terms = sorted({term for term in (highlight_terms or ()) if term}, key=by_length)
    if terms:
        # Terms are matched literally and case-insensitively. Placeholders stay
        # case-sensitive, consistent with process_pii.
        alternatives.append(
            r"\b(?P<term>(?i:" + "|".join(map(re.escape, terms)) + r"))\b"
        )
    return re.compile("|".join(alternatives))


def highlight_string(paragraph: str, highlight_terms: list) -> str:
    """Return *paragraph* as safe HTML with search terms bolded and PII redacted.

    Everything is done in a single pass over the raw text:
    * plain text is HTML-escaped, because it is corpus text embedded in a page;
    * whole-word, case-insensitive matches of any term are wrapped in ``<b>``
      and keep their original casing;
    * ``PI:<TAG>`` placeholders become redaction badges.

    A term can therefore never match inside markup inserted for another term
    or split a placeholder. Regex metacharacters in terms are treated literally.
    """
    pattern = _highlight_pattern(highlight_terms)
    pieces = []
    position = 0
    for match in pattern.finditer(paragraph):
        pieces.append(html.escape(paragraph[position:match.start()]))
        if match.group("pii") is not None:
            pieces.append(process_pii(match.group()))
        else:
            pieces.append(f"<b>{html.escape(match.group())}</b>")
        position = match.end()
    pieces.append(html.escape(paragraph[position:]))
    return "".join(pieces)


def process_results(hits: list, highlight_terms: list) -> str:
    """Render backend hits as a string of HTML result cards.

    Each hit is read as ``docid``, ``lang``, ``score``, ``text`` and
    ``meta["docs"]``; each sub-document as ``_id``, ``URL`` and ``TEXT``.
    Values taken from the corpus are HTML-escaped before being embedded.
    The image id is passed to ``load_image`` as a quoted JS string, so
    non-numeric ids work.

    Returns:
        The cards joined with single spaces (``""`` for no hits).
    """
    hit_list = []
    for rank, hit in enumerate(hits, start=1):
        res_head = f"""
<div class="searchresult">
<h2>{rank}. Document ID: {html.escape(str(hit['docid']))}</h2>
<p>Language: <strong>{html.escape(str(hit['lang']))}</strong>, Score: {round(hit['score'], 2)}</p>
"""
        for subhit in hit['meta']['docs']:
            image_id = html.escape(str(subhit['_id']))
            # JSON-quote for JavaScript, then escape for the HTML attribute.
            js_image_id = html.escape(json.dumps(str(subhit['_id'])))
            url = html.escape(str(subhit['URL']))
            res_head += f"""
<button onclick="load_image({js_image_id})">Load Image</button><br>
<p><img id='{image_id}' src='{url}' style="width:400px;height:auto;display:none;"></p>
<a href='{url}'>{url}</a>
<p>{highlight_string(subhit['TEXT'], highlight_terms)}</p>
"""
        res_head += f"""
<p>{highlight_string(hit['text'], highlight_terms)}</p>
</div>
<hr>
"""
        hit_list.append(res_head)
    return " ".join(hit_list)
if st.sidebar.button("Search"):
    # scisearch returns None for an empty query, an HTML message string when
    # the search could not be completed, or a (results, highlight_terms)
    # tuple. Unpacking unconditionally crashed on the first two cases.
    outcome = scisearch(query, LANG_MAPPING[language], max_results)
    message_html = ""
    hits, highlight_terms = [], []
    if isinstance(outcome, str):
        message_html = outcome
    elif outcome is not None:
        hits, highlight_terms = outcome
    html_results = message_html + process_results(hits, highlight_terms)
    # Report how many hits were actually returned, not the requested maximum.
    rendered_results = f"""
<div id="searchresultsarea">
<br>
<p id="searchresultsnumber">About {len(hits)} results</p>
{html_results}
</div>
"""
    st.markdown("""
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
""",
        unsafe_allow_html=True)
    st.markdown(
        """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
""",
        unsafe_allow_html=True)
    # Decorative search bar echoing the query. The query is escaped so quotes
    # or markup in it cannot break out of the value attribute.
    st.markdown(
        f"""
<div class="row no-gutters mt-3 align-items-center">
Gaia Search 🌖🌏
<div class="col col-md-4">
<input class="form-control border-secondary rounded-pill pr-5" type="search" value="{html.escape(query)}" id="example-search-input2">
</div>
<div class="col-auto">
<button class="btn btn-outline-light text-dark border-0 rounded-pill ml-n5" type="button">
<i class="fa fa-search"></i>
</button>
</div>
</div>
""",
        unsafe_allow_html=True)
    # Results go into an HTML component so the inline load_image() /
    # dark-mode scripts defined here are available to the rendered cards.
    components.html(
        """
<style>
#searchresultsarea {
font-family: 'Arial';
}
#searchresultsnumber {
font-size: 0.8rem;
color: gray;
}
.searchresult h2 {
font-size: 19px;
line-height: 18px;
font-weight: normal;
color: rgb(7, 111, 222);
margin-bottom: 0px;
margin-top: 25px;
}
.searchresult a {
font-size: 12px;
line-height: 12px;
color: green;
margin-bottom: 0px;
}
.dark-mode {
color: white;
}
</style>
<script>
function load_image(id){
console.log(id)
var x = document.getElementById(id);
console.log(x)
if (x.style.display === "none") {
x.style.display = "block";
} else {
x.style.display = "none";
}
};
function myFunction() {
var element = document.body;
element.classList.toggle("dark-mode");
}
</script>
<button onclick="myFunction()">Toggle dark mode</button>
""" + rendered_results,
        height=800,
        scrolling=True,
    )