import arxiv
import click
import json
import os

from datetime import datetime
from tqdm import tqdm


# TODO: add error handling for invalid search queries
# TODO: allow fetching all results for a given search query until exhaustion when max_results is not specified
@click.command()
@click.option(
    "--search_query", default="astro-ph", help="Search query for arXiv papers"
)
@click.option("--max_results", default=1000, help="Maximum number of results to fetch")
@click.option(
    "--sort_by",
    type=click.Choice(["relevance", "last_updated_date", "submitted_date"]),
    default="last_updated_date",
    help="Criterion to sort results by",
)
@click.option(
    "--sort_order",
    type=click.Choice(["asc", "desc"]),
    default="desc",
    help="Sort order (ascending or descending)",
)
@click.option(
    "--out_dir",
    default=None,
    help="Output directory for the fetched data",
)
@click.option("--out_file", default=None, help="Output file name")
@click.option(
    "--annotations_file",
    default="data/manual/human_annotations.jsonl",
    help="File with manual annotations that is reserved",
)
def main(
    search_query, max_results, sort_by, sort_order, out_dir, out_file, annotations_file
):
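    """Fetch arXiv papers for the given search query and write them to a JSONL file,
    skipping any papers whose titles appear in the reserved annotations file."""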
    annotations = []
    with open(annotations_file, "r") as f:
        for line in f:
            annotations.append(json.loads(line))

    # Titles serve as the deduplication key, so they must be unique in the annotations file.
    titles = set(ex["title"] for ex in annotations)
    assert len(titles) == len(annotations), "Duplicate titles in annotations file"

    results = get_arxiv_papers(search_query, max_results, sort_by, sort_order, titles)
    if len(results) == 0:
        click.echo("No new results to save.")
        return

    # Fall back to data/raw when no output subdirectory is given.
    data_dir = os.path.join("data", "raw", out_dir or "")
    os.makedirs(data_dir, exist_ok=True)
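    # Default file name encodes the query parameters plus a timestamp to avoid collisions.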
    if out_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_file = (
            f"{search_query}_{max_results}_{sort_by}_{sort_order}_{timestamp}.jsonl"
        )

    with open(os.path.join(data_dir, out_file), "w") as f:
        for result in results:
            f.write(json.dumps(result) + "\n")

    click.echo(f"Saved {len(results)} results to {os.path.join(data_dir, out_file)}")


def get_arxiv_papers(
    search_query, max_results, sort_by="relevance", sort_order="desc", annotated=None
):
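    """Query the arXiv API and return up to max_results papers as dictionaries,
    excluding any whose titles are in the annotated set."""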
    # Treat a missing annotation set as empty so the membership checks below don't fail.
    if annotated is None:
        annotated = set()

    client = arxiv.Client()

    sort_criterion = {
        "relevance": arxiv.SortCriterion.Relevance,
        "last_updated_date": arxiv.SortCriterion.LastUpdatedDate,
        "submitted_date": arxiv.SortCriterion.SubmittedDate,
    }[sort_by]

    sort_order = (
        arxiv.SortOrder.Descending
        if sort_order == "desc"
        else arxiv.SortOrder.Ascending
    )
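    # No server-side cap (max_results=None); fetching stops client-side once enough
    # non-overlapping papers have been collected.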
    search = arxiv.Search(
        query=search_query,
        max_results=None,
        sort_by=sort_criterion,
        sort_order=sort_order,
    )

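    # Keep only papers whose titles are not reserved in the annotation set.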
    non_overlapping_results = []
    pbar = tqdm(total=max_results, desc="Fetching papers")
    for result in client.results(search):
        if result.title not in annotated:
            non_overlapping_results.append(
                {
                    "id": result.entry_id,
                    "title": result.title,
                    "authors": [author.name for author in result.authors],
                    "abstract": result.summary,
                    "published": result.published.isoformat(),
                    "updated": result.updated.isoformat(),
                    "pdf_url": result.pdf_url,
                    "doi": result.doi,
                    "links": [link.href for link in result.links],
                    "journal_reference": result.journal_ref,
                    "primary_category": result.primary_category,
                    "categories": result.categories,
                }
            )
            pbar.update(1)
            if len(non_overlapping_results) >= max_results:
                break

    pbar.close()

    return non_overlapping_results


if __name__ == "__main__":
    main()