TrendFlow / mypages /home.py
Adapting's picture
1
0c73150
raw
history blame
6.59 kB
import streamlit as st
from .sidebar import render_sidebar
from requests_toolkit import ArxivQuery,IEEEQuery,PaperWithCodeQuery
from trendflow.lrt.clustering.clusters import SingleCluster
from trendflow.lrt.clustering.config import Configuration
from trendflow.lrt import ArticleList, LiteratureResearchTool
from trendflow.lrt_instance import *
from .charts import build_bar_charts
def home():
    """Render the home page: sidebar settings, the query form and the results.

    The sidebar provides the search settings; the form collects the query
    text and the preview toggle. Results are rendered only on the run in
    which the form's "Search" button was pressed.
    """
    # Every search setting except the query text itself comes from the sidebar.
    sidebar_settings = render_sidebar()
    platforms, number_papers, start_year, end_year, hyperparams = sidebar_settings

    # Number of papers passed to render_body as the per-platform preview size.
    preview_size = 5

    # Query form (page head).
    with st.form("my_form", clear_on_submit=False):
        st.markdown('''# 👋 Hi, enter your query here :)''')
        query_input = st.text_input(
            'Enter your query:',
            placeholder='''e.g. "Machine learning"''',
            value='',
        )
        show_preview = st.checkbox('show paper preview')
        # Streamlit requires every form to contain a submit button.
        submitted = st.form_submit_button("Search")

    # Nothing to show until the user has submitted the form.
    if not submitted:
        return

    # Page body: optional preview followed by the clustering results.
    render_body(
        platforms=platforms,
        num_papers=number_papers,
        num_papers_preview=preview_size,
        query_input=query_input,
        show_preview=show_preview,
        start_year=start_year,
        end_year=end_year,
        hyperparams=hyperparams,
        standardization=hyperparams['standardization'],
    )
# SECURITY(review): this IEEE API key was committed in plain text. It is kept
# only as a fallback so existing deployments keep working; rotate it and supply
# the real key through the IEEE_API_KEY environment variable instead.
_FALLBACK_IEEE_API_KEY = 'vpd9yy325enruv27zj2d353e'

# Column header shared by every per-platform preview table.
_PREVIEW_TABLE_HEADER = (
    '| ID| Paper Title | Publication Year |\n'
    '| -------- | -------- | -------- |\n'
)


def _preview_cell(value):
    """Return *value* as text that is safe inside one Markdown table cell."""
    # A newline would end the table row and a bare "|" would start a new
    # column, so flatten the former and escape the latter.
    return str(value).replace('\n', ' ').replace('|', '\\|')


def _preview_rows(records, year_key, limit):
    """Build the Markdown table rows for at most *limit* entries of *records*.

    Each record is indexed with ``'title'`` and *year_key*; rows are numbered
    from 1.
    """
    shown = min(len(records), limit)
    rows = []
    for pos in range(shown):
        record = records[pos]
        title = _preview_cell(record['title'])
        year = _preview_cell(record[year_key])
        rows.append(f'|{pos + 1}|{title}|{year}|\n')
    return ''.join(rows)


def __preview__(platforms, num_papers, num_papers_preview, query_input, start_year, end_year):
    """Show a Markdown preview table of the first papers found on each platform.

    Args:
        platforms: Collection of selected platform names; the names checked
            here are ``'IEEE'``, ``'Arxiv'`` and ``'Paper with Code'``.
        num_papers: Number of papers requested from each platform.
        num_papers_preview: Maximum number of rows listed per platform. The
            limit is applied to every platform independently, so a platform
            returning few results no longer shortens the tables after it.
        query_input: The user's search query.
        start_year: First publication year (only passed to the IEEE query).
        end_year: Last publication year (only passed to the IEEE query).
    """
    import os  # function-scope import keeps this block self-contained

    with st.spinner('Searching...'):
        placeholder = st.empty()  # filled with the overview of the papers at the end
        sections = [
            '# 0 Query Results Preview\n'
            'We have found following papers for you! '
            '(displaying 5 papers for each literature platforms)\n'
        ]

        if 'IEEE' in platforms:
            # Prefer a key from the environment over the committed fallback.
            api_key = os.environ.get('IEEE_API_KEY') or _FALLBACK_IEEE_API_KEY
            IEEEQuery.__setup_api_key__(api_key)
            ieee = IEEEQuery.query(query_input, start_year, end_year, num_papers)
            sections.append('## IEEE\n' + _PREVIEW_TABLE_HEADER)
            sections.append(_preview_rows(ieee, 'publication_year', num_papers_preview))

        if 'Arxiv' in platforms:
            arxiv = ArxivQuery.query(query_input, max_results=num_papers)
            sections.append('\n## Arxiv\n' + _PREVIEW_TABLE_HEADER)
            sections.append(_preview_rows(arxiv, 'published', num_papers_preview))

        if 'Paper with Code' in platforms:
            pwc = PaperWithCodeQuery.query(query_input, items_per_page=num_papers)
            sections.append('\n## Paper with Code\n' + _PREVIEW_TABLE_HEADER)
            sections.append(_preview_rows(pwc, 'published', num_papers_preview))

        placeholder.markdown(''.join(sections))
def _select_model(hyperparams):
    """Return the literature research tool matching the chosen *hyperparams*.

    The prebuilt ``baseline_lrt`` instance is reused when the settings equal
    the baseline combination; otherwise a new tool is configured.
    """
    is_baseline = (
        hyperparams['dimension_reduction'] == 'none'
        and hyperparams['model_cpt'] == 'keyphrase-transformer'
        and hyperparams['cluster_model'] == 'kmeans-euclidean'
    )
    if is_baseline:
        # NOTE(review): baseline_lrt is presumably provided by the
        # ``trendflow.lrt_instance`` star import - confirm.
        return baseline_lrt

    config = Configuration(
        plm='''all-mpnet-base-v2''',
        dimension_reduction=hyperparams['dimension_reduction'],
        clustering=hyperparams['cluster_model'],
        keywords_extraction=hyperparams['model_cpt'],
    )
    return LiteratureResearchTool(config)


def _render_platform_results(section_no, platform, clusters, articles):
    """Render the overview chart and per-cluster details for one platform."""
    st.markdown(f'''# {section_no} {platform} Results''')
    clusters.sort()

    # Overview: one bar group per cluster (paper count and keyphrase count).
    st.markdown(f'''## {section_no}.1 Clusters Overview''')
    st.markdown('''In this section we show the overview of the clusters, more specifically,''')
    st.markdown('''\n- the number of papers in each cluster\n- the number of keyphrases of each cluster''')
    st.bokeh_chart(build_bar_charts(
        x_range=[f'Cluster {pos + 1}' for pos in range(len(clusters))],
        y_names=['Number of Papers', 'Number of Keyphrases'],
        y_data=[[len(c) for c in clusters], [len(c.get_keyphrases()) for c in clusters]]
    ))

    # Details: article table and top keyphrases of every cluster.
    st.markdown(f'''## {section_no}.2 Cluster Details''')
    st.markdown('''In this section we show the details of each cluster, including''')
    st.markdown('''\n- the article information in the cluster\n- the keyphrases of the cluster''')
    for cluster_no, cluster in enumerate(clusters, start=1):
        article_ids = cluster.get_elements()
        articles_in_cluster = ArticleList([articles[article_id] for article_id in article_ids])
        st.markdown(f'''**Cluster {cluster_no}**''')
        st.dataframe(articles_in_cluster.to_dataframe())
        st.markdown('''The top 5 keyphrases of this cluster are:''')
        st.markdown(''.join(f'''- `{keyphrase}`\n''' for keyphrase in cluster.top_5_keyphrases))


def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview: bool, start_year, end_year,
                hyperparams: dict, standardization=False):
    """Render the results for *query_input*: optional preview, then clusters.

    Nothing beyond an empty placeholder is rendered when the query is empty
    or consists only of whitespace.

    Args:
        platforms: Selected literature platforms; one result section is
            rendered per entry, in order.
        num_papers: Number of papers requested per platform.
        num_papers_preview: Maximum number of papers listed per platform in
            the preview.
        query_input: The user's search query.
        show_preview: Whether to show the preview tables before clustering.
        start_year: First publication year passed to the search.
        end_year: Last publication year passed to the search.
        hyperparams: Settings dict; the keys read here are
            ``'dimension_reduction'``, ``'model_cpt'``, ``'cluster_model'``
            and ``'max_k'``.
        standardization: Forwarded unchanged to the model's ``yield_`` call.
    """
    query_echo = st.empty()
    # A blank query (empty or whitespace only) has nothing to search for.
    if not query_input or not query_input.strip():
        return
    query_echo.markdown(f'You entered query: `{query_input}`')

    if show_preview:
        __preview__(platforms, num_papers, num_papers_preview, query_input, start_year, end_year)

    with st.spinner("Clustering and generating..."):
        model = _select_model(hyperparams)
        generator = model.yield_(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'],
                                 platforms=platforms, standardization=standardization)
        # One (clusters, articles) pair is pulled from the generator for each
        # platform, so each section is rendered as soon as it is available.
        for section_no, plat in enumerate(platforms, start=1):
            clusters, articles = next(generator)
            _render_platform_results(section_no, plat, clusters, articles)