# Spaces:
# Running
# Running
import html
import http.client as http_client
import json
import logging
import os
import pprint
import re
import string

import requests
import streamlit as st
import streamlit.components.v1 as components
# Module-level pretty-printer configured with a two-space indent.
pp = pprint.PrettyPrinter(indent=2)

# Page title and wide layout for the Streamlit app.
st.set_page_config(layout="wide", page_title="Gaia Search")

# Make sure a `.streamlit` directory exists under the current working
# directory, then (re)write its config file so the light base theme is used.
os.makedirs(
    os.path.join(os.getcwd(), ".streamlit"),
    exist_ok=True,
)
with open(os.path.join(os.getcwd(), ".streamlit/config.toml"), "w") as file:
    file.write('[theme]\nbase="light"')
# Human-readable language names mapped to the short codes handed to the
# search function. Insertion order matters: it mirrors the order of the
# options shown in the language selector.
LANG_MAPPING = {
    "Arabic": "ar",
    "Catalan": "ca",
    "Code": "code",
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "Indonesian": "id",
    "Indic": "indic",
    "Niger-Congo": "nigercongo",
    "Portuguese": "pt",
    "Vietnamese": "vi",
    "Chinese": "zh",
    "Detect Language": "detect_language",
    "All": "all",
}
# Sidebar banner: large bold centered title plus a one-line description.
_SIDEBAR_TITLE_HTML = """
<style>
.aligncenter {
text-align: center;
font-weight: bold;
font-size: 50px;
}
</style>
<p class="aligncenter">Gaia Search 🌖🌏</p>
<p style="text-align: center;"> A search engine for the LAION large scale image caption corpora</p>
"""

# Sidebar link row (GitHub / project report) and a centered Colab badge.
# The link targets are empty in this version of the file.
_SIDEBAR_LINKS_HTML = """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="" >GitHub</a> | <a href="" >Project Report</a>
</p>
<p class="aligncenter">
<a href="" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
</a>
</p>
"""

# Both snippets are raw HTML/CSS, so Streamlit must be told not to escape them.
st.sidebar.markdown(_SIDEBAR_TITLE_HTML, unsafe_allow_html=True)
st.sidebar.markdown(_SIDEBAR_LINKS_HTML, unsafe_allow_html=True)
# Free-text query typed by the user.
query = st.sidebar.text_input(label="Search query", value="")

# Offer exactly the languages LANG_MAPPING knows about, in its order, so a
# selected label can always be translated to a code with LANG_MAPPING[label].
# The default is looked up by name rather than by a hard-coded position.
_language_options = tuple(LANG_MAPPING)
language = st.sidebar.selectbox(
    "Language",
    _language_options,
    index=_language_options.index("English"),
)

# Upper bound on the number of documents requested from the backend.
max_results = st.sidebar.slider(
    "Maximum Number of Results",
    min_value=1,
    max_value=100,
    step=1,
    value=10,
    help="Maximum Number of Documents to return",
)
# Fixed, full-width attribution bar pinned to the bottom of the page.
footer = """<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
<div class="footer">
<p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
</div>
"""

# Raw HTML/CSS: rendering must be explicitly allowed.
st.sidebar.markdown(
    footer,
    unsafe_allow_html=True,
)
def scisearch(query, language, num_results=10):
    """Send a search request to the backend and return its hits.

    The request is a JSON ``POST`` to the URL stored in the ``address``
    environment variable, with a 60 second timeout.

    Args:
        query: Query text. Leading/trailing whitespace is stripped; a blank
            or ``None`` query short-circuits without contacting the backend.
        language: Language code sent as ``lang``. The special value
            ``"detect_language"`` omits the field from the request.
        num_results: Number of hits to request (sent as ``k``).

    Returns:
        A ``(results, highlight_terms)`` tuple taken from the response
        payload. ``([], [])`` is returned for a blank query and whenever the
        search fails; in the failure case an explanatory HTML message is
        rendered on the page with ``st.markdown`` before returning, so the
        caller can always unpack two values.
    """
    if query is None or not query.strip():
        return [], []
    query = query.strip()

    post_data = {"query": query, "k": num_results}
    if language != "detect_language":
        post_data["lang"] = language

    try:
        output = requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        payload = json.loads(output.text)

        if "err" in payload and payload["err"]["type"] == "unsupported_lang":
            # The value comes from the remote service and ends up in HTML.
            detected_lang = html.escape(str(payload["err"]["meta"]["detected_lang"]))
            message = f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Detected language <b>{detected_lang}</b> is not supported.<br>
Please choose a language from the dropdown or type another query.
</p><br><hr><br>"""
        else:
            return payload["results"], payload["highlight_terms"]
    except Exception as e:
        # Boundary handler: network errors, malformed JSON and unexpected
        # payload shapes are all reported to the user instead of crashing.
        print(e)
        message = f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Raised {type(e).__name__}</p>
<p style='font-size:14px; font-family: Arial; '>
Check if a relevant discussion already exists in the Community tab. If not, please open a discussion.
</p>
"""

    st.markdown(message, unsafe_allow_html=True)
    return [], []
# Tag names that, when prefixed with PII_PREFIX, form the placeholder tokens
# replaced by process_pii (presumably inserted upstream by a PII-scrubbing
# step -- confirm against the index pipeline).
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every ``PI:<TAG>`` placeholder with a visible redaction badge.

    Args:
        text: Text (or HTML) that may contain placeholder tokens.

    Returns:
        The text with each known placeholder replaced by a bold, marked
        ``REDACTED <TAG>`` HTML snippet.
    """
    for tag in PII_TAGS:
        text = text.replace(
            PII_PREFIX + tag,
            """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>""".format(tag),
        )
    return text


def highlight_string(paragraph: str, highlight_terms: list) -> str:
    """Return *paragraph* as HTML with the given terms wrapped in ``<b>``.

    Matching is case-insensitive and limited to whole words; the matched
    text keeps its original casing. Terms are treated as literal text, not
    as regular expressions. Everything outside the inserted tags is
    HTML-escaped, and PII placeholders are rendered via ``process_pii``.

    Args:
        paragraph: Plain text to render.
        highlight_terms: Terms to emphasise; empty entries are ignored.

    Returns:
        An HTML fragment.
    """
    escaped_terms = [re.escape(str(term)) for term in (highlight_terms or []) if term]
    if not escaped_terms:
        return process_pii(html.escape(paragraph, quote=False))

    # One alternation, one pass: a later term can never match inside the
    # <b> tags produced for an earlier one.
    pattern = re.compile(r"\b(?:" + "|".join(escaped_terms) + r")\b", flags=re.I)

    pieces = []
    cursor = 0
    for match in pattern.finditer(paragraph):
        pieces.append(html.escape(paragraph[cursor:match.start()], quote=False))
        pieces.append(f"<b>{html.escape(match.group(0), quote=False)}</b>")
        cursor = match.end()
    pieces.append(html.escape(paragraph[cursor:], quote=False))

    return process_pii("".join(pieces))
def process_results(hits: list, highlight_terms: list) -> str:
    """Render search hits as an HTML fragment.

    Each hit becomes a ``<div class="searchresult">`` holding its rank,
    document id, language and rounded score, then one entry per document in
    ``hit['meta']['docs']`` (a "Load Image" toggle button, the hidden image,
    its URL and highlighted caption), and finally the highlighted hit text.

    Values taken from the hits are HTML-escaped before being placed in
    markup, and the image id is passed to ``load_image`` as a quoted
    JavaScript string so non-numeric ids work.

    Args:
        hits: Hit dictionaries with ``docid``, ``lang``, ``score``, ``text``
            and ``meta.docs`` entries (each with ``_id``, ``URL``, ``TEXT``).
            ``None`` or an empty list yields an empty string.
        highlight_terms: Terms forwarded to ``highlight_string``.

    Returns:
        The per-hit fragments joined with single spaces.
    """
    rendered_hits = []
    for rank, hit in enumerate(hits or [], start=1):
        docid = html.escape(str(hit["docid"]))
        lang = html.escape(str(hit["lang"]))
        score = round(hit["score"], 2)
        parts = [
            f"""
<div class="searchresult">
<h2>{rank}. Document ID: {docid}</h2>
<p>Language: <strong>{lang}</strong>, Score: {score}</p>
"""
        ]

        for subhit in hit["meta"]["docs"]:
            raw_id = str(subhit["_id"])
            element_id = html.escape(raw_id)
            # JSON gives a valid JS string literal; escaping it makes the
            # literal safe inside the double-quoted onclick attribute.
            js_id = html.escape(json.dumps(raw_id))
            url = html.escape(str(subhit["URL"]))
            caption = highlight_string(subhit["TEXT"], highlight_terms)
            parts.append(
                f"""
<button onclick="load_image({js_id})">Load Image</button><br>
<p><img id='{element_id}' src='{url}' style="width:400px;height:auto;display:none;"></p>
<a href='{url}'>{url}</a>
<p>{caption}</p>
"""
            )

        body = highlight_string(hit["text"], highlight_terms)
        parts.append(
            f"""
<p>{body}</p>
</div>
<hr>
"""
        )
        rendered_hits.append("".join(parts))
    return " ".join(rendered_hits)
# Everything below runs only on the script run triggered by the Search button.
if st.sidebar.button("Search"):
    # scisearch is expected to return a (hits, highlight_terms) pair; tolerate
    # a bare HTML message string or None as well, so the page renders instead
    # of failing on tuple unpacking.
    search_output = scisearch(query, LANG_MAPPING[language], max_results)
    if isinstance(search_output, tuple):
        hits, highlight_terms = search_output
        html_results = process_results(hits, highlight_terms)
    else:
        hits, highlight_terms = [], []
        html_results = search_output if isinstance(search_output, str) else ""

    # Report the number of hits actually returned, not the requested maximum.
    rendered_results = f"""
<div id="searchresultsarea">
<br>
<p id="searchresultsnumber">About {len(hits)} results</p>
{html_results}
</div>
"""

    # Third-party stylesheets used by the search bar below.
    st.markdown(
        """
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
""",
        unsafe_allow_html=True,
    )

    # The query is user input placed inside an HTML attribute: escape it so
    # quotes or angle brackets cannot break out of value="...".
    safe_query = html.escape(query, quote=True)
    st.markdown(
        f"""
<div class="row no-gutters mt-3 align-items-center">
Gaia Search 🌖🌏
<div class="col col-md-4">
<input class="form-control border-secondary rounded-pill pr-5" type="search" value="{safe_query}" id="example-search-input2">
</div>
<div class="col-auto">
<button class="btn btn-outline-light text-dark border-0 rounded-pill ml-n5" type="button">
<i class="fa fa-search"></i>
</button>
</div>
</div>
""",
        unsafe_allow_html=True,
    )

    # Results are rendered through components.html together with their
    # styles and the small scripts the result markup calls (load_image
    # toggles an image's visibility; myFunction toggles the dark-mode class).
    components.html(
        """
<style>
#searchresultsarea {
font-family: 'Arial';
}
#searchresultsnumber {
font-size: 0.8rem;
color: gray;
}
.searchresult h2 {
font-size: 19px;
line-height: 18px;
font-weight: normal;
color: rgb(7, 111, 222);
margin-bottom: 0px;
margin-top: 25px;
}
.searchresult a {
font-size: 12px;
line-height: 12px;
color: green;
margin-bottom: 0px;
}
.dark-mode {
color: white;
}
</style>
<script>
function load_image(id){
console.log(id)
var x = document.getElementById(id);
console.log(x)
if (x.style.display === "none") {
x.style.display = "block";
} else {
x.style.display = "none";
}
};
function myFunction() {
var element = document.body;
element.classList.toggle("dark-mode");
}
</script>
<button onclick="myFunction()">Toggle dark mode</button>
"""
        + rendered_results,
        height=800,
        scrolling=True,
    )