# Spaces: Sleeping
#!/usr/bin/env python | |
import datetime
import html
import operator

import datasets
import gradio as gr
import pandas as pd
import tqdm.auto
from apscheduler.schedulers.background import BackgroundScheduler
from gradio_calendar import Calendar
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel
# --- Data Loading and Processing ---
api = HfApi()

# Hub dataset repo that holds the prebuilt abstract index, and the local directory
# (inside RAGatouille's ``.ragatouille/colbert/indexes`` tree) it is mirrored into.
INDEX_REPO_ID = "hysts-bot-data/daily-papers-abstract-index"
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/daily-papers-abstract-index/"

api.snapshot_download(repo_id=INDEX_REPO_ID, repo_type="dataset", local_dir=INDEX_DIR_PATH)
abstract_retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Fire one throwaway query at import time so the retriever is initialised
# before the first real user search reaches it.
abstract_retriever.search("LLM")
def update_abstract_index() -> None:
    """Re-download the abstract index from the Hub and swap in a fresh retriever.

    The replacement retriever is loaded and warmed up in a local variable first and
    only then published to the module-level ``abstract_retriever``. Concurrent
    searches therefore keep using the previous, already-initialised retriever until
    the new one is ready. If loading or the warm-up query raises, the previous
    retriever stays in place.
    """
    global abstract_retriever
    # NOTE(review): this writes into the same directory the live retriever was loaded
    # from; whether that retriever still reads those files lazily is not visible here.
    api.snapshot_download(
        repo_id=INDEX_REPO_ID,
        repo_type="dataset",
        local_dir=INDEX_DIR_PATH,
    )
    retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
    # Warm-up query, mirroring the one issued at import time.
    retriever.search("LLM")
    abstract_retriever = retriever
scheduler_abstract = BackgroundScheduler()
# Refresh the abstract index at minute 0 of every hour (UTC). A run that fires up
# to three minutes (180 s) late is still allowed to execute.
scheduler_abstract.add_job(
    update_abstract_index,
    "cron",
    minute=0,
    timezone="UTC",
    misfire_grace_time=180,
)
scheduler_abstract.start()
def get_df() -> pd.DataFrame: | |
df = pd.merge( | |
left=datasets.load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas(), | |
right=datasets.load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas(), | |
on="arxiv_id", | |
) | |
df = df[::-1].reset_index(drop=True) | |
df["date"] = df["date"].dt.strftime("%Y-%m-%d") | |
paper_info = [] | |
for _, row in tqdm.auto.tqdm(df.iterrows(), total=len(df)): | |
info = row.copy() | |
del info["abstract"] | |
info["paper_page"] = f"https://huggingface.co/papers/{row.arxiv_id}" | |
paper_info.append(info) | |
return pd.DataFrame(paper_info) | |
class Prettifier: | |
def get_github_link(link: str) -> str: | |
if not link: | |
return "" | |
return Prettifier.create_link("github", link) | |
def create_link(text: str, url: str) -> str: | |
return f'<a href="{url}" target="_blank">{text}</a>' | |
def to_div(text: str | None, category_name: str) -> str: | |
if text is None: | |
text = "" | |
class_name = f"{category_name}-{text.lower()}" | |
return f'<div class="{class_name}">{text}</div>' | |
def __call__(self, df: pd.DataFrame) -> pd.DataFrame: | |
new_rows = [] | |
for _, row in df.iterrows(): | |
new_row = { | |
"date": Prettifier.create_link(row.date, f"https://huggingface.co/papers?date={row.date}"), | |
"paper_page": Prettifier.create_link(row.arxiv_id, row.paper_page), | |
"title": row["title"], | |
"github": self.get_github_link(row.github), | |
"๐": row["upvotes"], | |
"๐ฌ": row["num_comments"], | |
} | |
new_rows.append(new_row) | |
return pd.DataFrame(new_rows) | |
class PaperList: | |
COLUMN_INFO = [ | |
["date", "markdown"], | |
["paper_page", "markdown"], | |
["title", "str"], | |
["github", "markdown"], | |
["๐", "number"], | |
["๐ฌ", "number"], | |
] | |
def __init__(self, df: pd.DataFrame): | |
self.df_raw = df | |
self._prettifier = Prettifier() | |
self.df_prettified = self._prettifier(df).loc[:, self.column_names] | |
def column_names(self): | |
return list(map(operator.itemgetter(0), self.COLUMN_INFO)) | |
def column_datatype(self): | |
return list(map(operator.itemgetter(1), self.COLUMN_INFO)) | |
def search( | |
self, | |
start_date: datetime.datetime, | |
end_date: datetime.datetime, | |
title_search_query: str, | |
abstract_search_query: str, | |
max_num_to_retrieve: int, | |
) -> pd.DataFrame: | |
df = self.df_raw.copy() | |
df["date"] = pd.to_datetime(df["date"]) | |
# Filter by date | |
df = df[(df["date"] >= start_date) & (df["date"] <= end_date)] | |
df["date"] = df["date"].dt.strftime("%Y-%m-%d") | |
# Filter by title | |
if title_search_query: | |
df = df[df["title"].str.contains(title_search_query, case=False)] | |
# Filter by abstract | |
if abstract_search_query: | |
results = abstract_retriever.search(abstract_search_query, k=max_num_to_retrieve) | |
remaining_ids = set(df["arxiv_id"]) | |
found_id_set = set() | |
found_ids = [] | |
for x in results: | |
arxiv_id = x["document_id"] | |
if arxiv_id not in remaining_ids: | |
continue | |
if arxiv_id in found_id_set: | |
continue | |
found_id_set.add(arxiv_id) | |
found_ids.append(arxiv_id) | |
df = df[df["arxiv_id"].isin(found_ids)].set_index("arxiv_id").reindex(index=found_ids).reset_index() | |
df_prettified = self._prettifier(df).loc[:, self.column_names] | |
return df_prettified | |
# Module-level table read by the Gradio callbacks; update_paper_list() rebinds it.
paper_list = PaperList(df=get_df())
def update_paper_list() -> None:
    """Rebuild the module-level ``paper_list`` from freshly loaded datasets.

    The new PaperList is fully constructed before the global is rebound, so
    callbacks keep reading the previous table until the replacement is ready.
    """
    global paper_list
    refreshed = PaperList(get_df())
    paper_list = refreshed
scheduler_data = BackgroundScheduler()
# Rebuild the paper table at minute 0 of every hour (UTC). A run that fires up
# to 60 seconds late is still allowed to execute.
scheduler_data.add_job(
    update_paper_list,
    "cron",
    minute=0,
    timezone="UTC",
    misfire_grace_time=60,
)
scheduler_data.start()
# --- Gradio App ---
# Markdown header rendered at the top of the page.
DESCRIPTION = "# [Daily Papers](https://huggingface.co/papers)"
# Markdown footer listing related Spaces; each entry ends with a newline.
FOOT_NOTE = (
    "Related useful Spaces:\n"
    "- [Semantic Scholar Paper Recommender](https://huggingface.co/spaces/librarian-bots/recommend_similar_papers) by [davanstrien](https://huggingface.co/davanstrien)\n"
    "- [ArXiv CS RAG](https://huggingface.co/spaces/bishmoy/Arxiv-CS-RAG) by [bishmoy](https://huggingface.co/bishmoy)\n"
    "- [Paper Q&A](https://huggingface.co/spaces/chansung/paper_qa) by [chansung](https://huggingface.co/chansung)\n"
)
def update_df() -> pd.DataFrame:
    """Return the current, unfiltered prettified table (used as the page-load value)."""
    # Read the global at call time: update_paper_list() rebinds it periodically.
    current = paper_list
    return current.df_prettified
def update_num_papers(df: pd.DataFrame) -> str:
    """Format the "<shown> / <total>" counter displayed above the table."""
    shown, total = len(df), len(paper_list.df_raw)
    return f"{shown} / {total}"
def search(
    start_date: datetime.datetime,
    end_date: datetime.datetime,
    search_title: str,
    search_abstract: str,
    max_num_to_retrieve: int,
) -> pd.DataFrame:
    """Gradio callback that runs a search against the current ``paper_list``.

    The global is looked up on every call, so results always come from the most
    recently rebuilt table.
    """
    return paper_list.search(
        start_date=start_date,
        end_date=end_date,
        title_search_query=search_title,
        abstract_search_query=search_abstract,
        max_num_to_retrieve=max_num_to_retrieve,
    )
def _today_utc() -> str:
    """Return today's UTC date as ``YYYY-MM-DD`` (the format used for Calendar values)."""
    # datetime.utcnow() is deprecated since Python 3.12; use an aware "now" instead.
    return datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=len(paper_list.df_raw),
                    step=1,
                    value=100,
                )
    with gr.Row():
        start_date = Calendar(label="Start date", type="date", value="2023-05-05")
        end_date = Calendar(label="End date", type="date", value=_today_utc())
    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(paper_list.df_raw), interactive=False)
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.column_datatype,
        type="pandas",
        interactive=False,
        height=1000,
        elem_id="table",
        column_widths=["10%", "10%", "60%", "10%", "5%", "5%"],
        wrap=True,
    )
    gr.Markdown(FOOT_NOTE)

    # Define the triggers and corresponding functions.
    search_event = gr.Button("Search")
    search_inputs = [start_date, end_date, search_title, search_abstract, max_num_to_retrieve]
    # The Search button and a change to any input run the same search, then refresh the counter.
    for listener in [search_event.click, *(component.change for component in search_inputs)]:
        listener(
            fn=search,
            inputs=search_inputs,
            outputs=df,
        ).then(
            fn=update_num_papers,
            inputs=df,
            outputs=num_papers,
            queue=False,
        )

    # Load the initial dataframe and number of papers.
    demo.load(
        fn=update_df,
        outputs=df,
        queue=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
    )
    # The end-date default above is evaluated once, at startup. Refresh it on every
    # page load so a long-running process does not hide papers dated after it started.
    demo.load(fn=_today_utc, outputs=end_date, queue=False)
if __name__ == "__main__":
    # Configure the request queue (api_open=False), then launch with show_api=False.
    demo.queue(api_open=False)
    demo.launch(show_api=False)