File size: 14,303 Bytes
985ef3e
5800f83
 
 
 
 
 
abf93bc
 
985ef3e
3a03f47
985ef3e
 
22f7397
 
 
 
 
 
7bf4167
22f7397
 
7c5f9b0
 
 
 
 
22f7397
 
 
 
 
 
 
 
 
 
7bf4167
 
 
fc74e8d
 
 
 
 
955747f
 
 
 
4d14899
 
955747f
a0eb5f6
 
 
 
 
 
a2d12d3
 
 
 
a0eb5f6
 
 
 
184fac5
 
 
 
 
4d14899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
955747f
4d14899
184fac5
 
4d14899
 
 
 
 
 
c985084
4d14899
 
 
c985084
4d14899
 
c985084
4d14899
 
c985084
4d14899
 
955747f
c985084
 
4d14899
 
c985084
4d14899
184fac5
955747f
4d14899
 
 
c985084
 
4d14899
fc74e8d
 
a0eb5f6
 
 
 
4ea4863
 
 
a0eb5f6
 
 
 
 
 
b57342d
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4419e34
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
4ea4863
 
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68fabe
88d941f
c65d57e
 
4d14899
 
 
c65d57e
4d14899
 
 
c65d57e
4d14899
c65d57e
4d14899
 
 
 
c65d57e
4d14899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68fabe
 
955747f
 
 
394512b
a0eb5f6
 
 
 
88d941f
394512b
6781558
 
58962c9
 
d68fabe
58962c9
394512b
88d941f
394512b
58962c9
88d941f
 
58962c9
394512b
a0eb5f6
 
22f7397
 
 
fc74e8d
955747f
fc74e8d
394512b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05560e1
394512b
fc74e8d
58962c9
4f40d11
d68fabe
394512b
 
 
d68fabe
 
394512b
 
 
 
d68fabe
394512b
 
 
 
 
 
 
 
 
d68fabe
394512b
 
 
 
 
 
fc74e8d
394512b
05560e1
fc74e8d
955747f
7bf4167
 
 
22f7397
394512b
 
 
 
 
 
 
 
 
d68fabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf4167
22f7397
955747f
d68fabe
 
 
955747f
c65d57e
d68fabe
c65d57e
 
 
d68fabe
 
 
 
 
 
 
 
 
 
394512b
 
 
 
 
 
 
 
955747f
 
 
c65d57e
e365271
53c4d50
955747f
 
4566fdb
 
53c4d50
 
 
611f1cc
53c4d50
 
4566fdb
 
 
 
 
 
 
 
 
 
 
 
53c4d50
 
 
4566fdb
 
955747f
53c4d50
955747f
 
22f7397
 
 
 
 
 
7bf4167
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
import os
#os.system("pip uninstall -y gradio")
#os.system("pip install --upgrade gradio")
#os.system("pip install datamapplot==0.3.0")
#os.system("pip install numba==0.59.1")
#os.system("pip install umap-learn==0.5.6")
#os.system("pip install pynndescent==0.5.12")



import spaces


from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
import sys



# Register the local "static/" directory with Gradio as a static-file path;
# predict() writes the generated plot HTML files into this directory.
gr.set_static_paths(paths=["static/"])



# create a FastAPI app; the Gradio UI is mounted onto it at the bottom of the file
app = FastAPI()

# create a static directory to store the static files (generated plot pages).
# It is created up front, presumably because StaticFiles below expects an
# existing directory — NOTE(review): confirm against the installed Starlette.
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# mount FastAPI StaticFiles server so that /static/<file> URLs (used by the
# link and iframe returned from predict) resolve to files in static_dir
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Gradio stuff



import datamapplot
import numpy as np
import requests
import io
import pandas as pd
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from itertools import chain
from compress_pickle import load, dump
from urllib.parse import urlparse, parse_qs
import re



from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
from tqdm import tqdm
from numba.typed import List
import pickle
import pynndescent
import umap









def openalex_url_to_pyalex_query(url):
    """
    Translate an OpenAlex search URL into an equivalent pyalex ``Works`` query.

    The ``filter`` query parameter is split on commas into ``key:value`` pairs.
    ``default.search`` becomes a full-text ``search``; every other key becomes a
    ``filter``. Pairs without a colon are ignored. The ``sort`` parameter is
    split on commas too; a leading ``-`` selects descending order.

    Args:
    url (str): The OpenAlex search URL.

    Returns:
    tuple: (Works object, dict of the raw ``page``/``per-page``/``sample``/``seed``
    parameters that were present in the URL, as strings)
    """
    qs = parse_qs(urlparse(url).query)
    works = Works()

    raw_filters = qs['filter'][0].split(',') if 'filter' in qs else []
    for item in raw_filters:
        name, colon, value = item.partition(':')
        if not colon:
            # No "key:value" separator -> nothing usable in this item.
            continue
        if name == 'default.search':
            works = works.search(value)
        else:
            works = works.filter(**{name: value})

    raw_sorts = qs['sort'][0].split(',') if 'sort' in qs else []
    for field in raw_sorts:
        descending = field.startswith('-')
        field_name = field[1:] if descending else field
        works = works.sort(**{field_name: 'desc' if descending else 'asc'})

    extras = {
        key: qs[key][0]
        for key in ('page', 'per-page', 'sample', 'seed')
        if key in qs
    }
    return works, extras


def invert_abstract(inv_index):
    """Rebuild plain abstract text from an OpenAlex-style inverted index.

    ``inv_index`` maps each word to the list of positions where it occurs.
    Every (word, position) occurrence is placed in position order (ties keep
    their original order) and the words are joined with single spaces.
    Returns a single space when ``inv_index`` is ``None``.
    """
    if inv_index is None:
        return ' '
    occurrences = [
        (position, word)
        for word, positions in inv_index.items()
        for position in positions
    ]
    occurrences.sort(key=lambda pair: pair[0])
    return " ".join(word for _, word in occurrences)

def get_pub(x):
    """Return the publication venue name of an OpenAlex location entry.

    ``x`` is read as ``x['source']['display_name']``; it may also be ``None``
    or NaN (e.g. a missing ``primary_location`` cell in a DataFrame).

    Returns:
        The display name. Returns ``' '`` when it cannot be read (missing keys,
        or ``x`` / ``x['source']`` not subscriptable) or when it is one of the
        placeholder names ``'parsed_publication'`` / ``'Deleted Journal'``.
        A ``None`` display name is passed through unchanged.
    """
    try:
        source = x['source']['display_name']
    except (KeyError, TypeError):
        # KeyError: a key is absent; TypeError: x or x['source'] is None/NaN
        # or otherwise not subscriptable by a string key.
        return ' '
    if source in ('parsed_publication', 'Deleted Journal'):
        return ' '
    return source

#def query_records(search_term):


#     # Fetch records based on the search term in the abstract!
#     query = Works().search([search_term])
#     query_length = Works().search([search_term]).count()
    
#     records = []
#     #total_pages = (query_length + 199) // 200  # Calculate total number of pages
#     progress=gr.Progress()
    
#     for i, record in progress.tqdm(enumerate(chain(*query.paginate(per_page=200)))):
#         records.append(record)
        
#         # Calculate progress from 0 to 0.1
#         #achieved_progress = min(0.1, (i + 1) / query_length * 0.1)
        
#         # Update progress bar
#         #progress(achieved_progress, desc="Getting queried data...")


    
#     records_df = pd.DataFrame(records)
#     records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]

#     records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]

    
#     records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
#     records_df['abstract'] = records_df['abstract'].fillna(' ')
#     records_df['title'] = records_df['title'].fillna(' ')

    
#     return records_df


################# Setting up the model for specter2 embeddings ###################



# Earlier device selection used for Apple-silicon (MPS) development, kept for reference:
#device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
#print(f"Using device: {device}")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The tokenizer and base SPECTER2 model are loaded once at import time; the
# "proximity" adapter is loaded and activated inside create_embeddings().
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_aug2023refresh_base')


@spaces.GPU(duration=60)
def create_embeddings(texts_to_embedd):
    """Embed texts with the SPECTER2 proximity adapter.

    Args:
        texts_to_embedd: sliceable sequence of strings, one document each.

    Returns:
        numpy.ndarray with one row per input text: the [CLS]-token vector of
        the model's last hidden state, computed on ``device`` and copied to CPU.
        An empty input yields an array with zero rows.
    """
    print(len(texts_to_embedd))

    # torch.cat() below raises on an empty list of batches, so short-circuit.
    # NOTE(review): assumes the model config exposes hidden_size (BERT-style).
    if len(texts_to_embedd) == 0:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    # Load the proximity adapter and activate it. This is done on every call;
    # NOTE(review): under ZeroGPU this function may run in a separate worker,
    # so state set here may not persist — confirm before hoisting to import time.
    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")
    model.to(device)
    model.eval()

    batch_size = 32
    embedding_batches = []
    with torch.no_grad():
        batch_starts = range(0, len(texts_to_embedd), batch_size)
        for batch_number, start in enumerate(tqdm(batch_starts), start=1):
            batch = texts_to_embedd[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
            outputs = model(**inputs)
            # Take the [CLS] token representation of every sequence in the batch
            # and move it to the CPU right away to free GPU memory.
            embedding_batches.append(outputs.last_hidden_state[:, 0, :].cpu())
            # Periodically release cached GPU memory; only meaningful on CUDA.
            if device.type == "cuda" and batch_number % 100 == 0:
                torch.cuda.empty_cache()

    return torch.cat(embedding_batches, dim=0).numpy()





def _fetch_openalex_records(query, query_length, progress):
    """Download every record of ``query`` page by page, advancing ``progress`` to 0.3.

    Args:
        query: pyalex query object to paginate (200 records per page).
        query_length: count reported by ``query.count()``; only used to scale
            the progress bar.
        progress: Gradio progress callback.

    Returns:
        list of the records yielded by the paginator.
    """
    records = []
    # Iterate pages lazily so the progress bar moves while pages download.
    for page in query.paginate(per_page=200):
        records.extend(page)
        # query_length may be 0, and the number actually fetched may differ
        # from it, so guard the division and clamp the fraction to [0, 1].
        fraction = min(len(records) / query_length, 1.0) if query_length else 1.0
        progress(0.3 * fraction, desc="Getting queried data...")
    return records


def _records_to_dataframe(records):
    """Build a DataFrame of OpenAlex records with text columns ready for embedding.

    Adds ``abstract`` (rebuilt from ``abstract_inverted_index``) and
    ``parsed_publication`` (venue name from ``primary_location``), and fills
    missing values in those two columns and ``title`` with a single space.
    """
    records_df = pd.DataFrame(records)
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
    for column in ('parsed_publication', 'abstract', 'title'):
        records_df[column] = records_df[column].fillna(' ')
    return records_df


def predict(text_input, sample_size_slider, reduce_sample_checkbox, progress=gr.Progress()):
    """Fetch an OpenAlex query, embed it and plot it on top of the basemap.

    Args:
        text_input: OpenAlex search URL.
        sample_size_slider: maximum number of fetched records to keep when
            sampling is enabled.
        reduce_sample_checkbox: if truthy, randomly sample the fetched records
            down to ``sample_size_slider`` (or keep all of them if fewer were
            fetched).
        progress: Gradio progress tracker.

    Returns:
        tuple ``(link, iframe)`` of HTML strings pointing at the saved plot
        under ``/static/``.

    Raises:
        gradio.Error: if the query returns no records.
    """
    import uuid
    from datetime import timezone

    print('getting data to project')
    progress(0, desc="Starting...")

    query, _params = openalex_url_to_pyalex_query(text_input)
    query_length = query.count()
    print(f'Requesting {query_length} entries...')

    records = _fetch_openalex_records(query, query_length, progress)
    if not records:
        raise gr.Error("The OpenAlex query returned no records. Please check the URL.")
    records_df = _records_to_dataframe(records)

    if reduce_sample_checkbox:
        # DataFrame.sample() requires an int n and raises if n exceeds the
        # number of rows (without replacement), so cast and clamp.
        sample_size = min(int(sample_size_slider), len(records_df))
        records_df = records_df.sample(sample_size)
    print(records_df)

    progress(0.3, desc="Embedding Data...")
    sep = tokenizer.sep_token
    texts_to_embedd = [
        title + sep + publication + sep + abstract
        for title, publication, abstract in zip(
            records_df['title'], records_df['parsed_publication'], records_df['abstract']
        )
    ]

    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)

    progress(0.5, desc="Project into UMAP-embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x', 'y']] = umap_embeddings

    # Basemap points get a translucent grey, query hits orange. The colour is
    # assigned on a copy so the shared module-level basemap is not mutated on
    # every request.
    basemap_df = basedata_df.assign(color='#ced4d211')
    records_df['color'] = '#f98e31'

    progress(0.6, desc="Set up data...")

    stacked_df = pd.concat([basemap_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    print(stacked_df)

    # The doi column is exposed to the on_click handler as {doi}.
    extra_data = pd.DataFrame(stacked_df['doi'])

    # Epoch seconds plus a random suffix: unique even when two users run a
    # query within the same second, and portable (strftime('%s') is a glibc
    # extension and interprets utcnow()'s naive value as local time).
    timestamp = int(datetime.now(timezone.utc).timestamp())
    file_name = f"{timestamp}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Plotting...")

    custom_css = """
        #title-container {
            background: #edededaa;
            border-radius: 2px;
            box-shadow: 2px 3px 10px #aaaaaa00;
        }

        #search-container {
            position: fixed !important;
            top: 20px !important;
            right: 20px !important;
            left: auto !important;
            width: 200px !important;
            z-index: 9999 !important;
        }

        #search {
            /* padding: 8px 8px !important; */
            /* border: none !important; */
            /* border-radius: 20px !important; */
            background-color: #ffffffaa !important;
            font-family: 'Roboto Condensed', sans-serif !important;
            font-size: 14px;
            /* box-shadow: 0 0px 0px #aaaaaa00 !important; */
        }
    """

    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_1_labels']),
        np.array(stacked_df['cluster_2_labels']),
        np.array(stacked_df['cluster_3_labels']),
        hover_text=stacked_df['title'].astype(str).tolist(),
        marker_color_array=stacked_df['color'],
        use_medoids=True,
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        color_label_text=False,
        font_family="Roboto Condensed",
        font_weight=700,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=custom_css,
        initial_zoom_fraction=.8,
        enable_search=True)

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)

    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe






################ MAIN BLOCK #####################



# with gr.Blocks() as block:
#     gr.Markdown("""
# ## Mapping OpenAlex-Queries

# This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them. 
# Your a philosopher of science who wonders where the concept of a fitnesslandscape has appeared...
# """)
#     with gr.Row():
#         with gr.Column():
#             text_input = gr.Textbox(label="Name")
#             markdown = gr.Markdown(label="Output Box")
#             new_btn = gr.Button("New")
#         with gr.Column():
#             html = gr.HTML(label="HTML preview", show_label=True)

#     new_btn.click(fn=predict, inputs=[text_input], outputs=[markdown, html])



# Gradio UI: an OpenAlex URL input plus sampling controls. Clicking the button
# runs predict() and shows a link to, and an iframe preview of, the saved plot.
with gr.Blocks() as block:
    gr.Markdown("""
    ## Mapping OpenAlex-Queries

    Enter the URL to an OpenAlex-search below. It will take a few minutes, but then the result will be projected onto a map of the OA database as a whole.
    """)

#    This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them. 
#    You're a philosopher of science who wonders where the concept of a fitness landscape has appeared...
    
    with gr.Column():
        # predict() parses this value as an OpenAlex search URL.
        text_input = gr.Textbox(label="OpenAlex Fulltext-Search")
        # Number of records to keep when sampling is enabled below.
        sample_size_slider = gr.Slider(label="Sample Size", minimum=10, maximum=20000, step=10, value=1000)
        reduce_sample_checkbox = gr.Checkbox(label="Reduce Sample Size", value=True)
        new_btn = gr.Button("Run Query")
        # Receives the link to the generated HTML file (first output of predict).
        markdown = gr.Markdown(label="")
        # Receives the iframe embedding the generated plot (second output of predict).
        html = gr.HTML(label="HTML preview", show_label=True)
    
    # predict(text, sample_size, reduce_sample) -> (link_html, iframe_html)
    new_btn.click(fn=predict, inputs=[text_input, sample_size_slider, reduce_sample_checkbox], outputs=[markdown, html])








def setup_basemap_data(path='100k_filtered_OA_sample_cluster_and_positions.pkl'):
    """Load the precomputed basemap data from a local pickle file.

    Args:
        path: pickle file to read; defaults to the basemap file used by the app.

    Returns:
        The unpickled object. predict() concatenates it with the query results
        and reads x/y, title, doi and cluster_{1,2,3}_labels columns from the
        result, so it is presumably a pandas DataFrame with those columns.
    """
    print("getting basemap data...")
    # pickle can execute arbitrary code: only point this at trusted files.
    # The context manager closes the handle, which the bare open() leaked.
    with open(path, 'rb') as handle:
        basedata_df = pickle.load(handle)
    print(basedata_df)
    return basedata_df


    
def setup_mapper(path='umap_mapper_300k_random_OA_specter_2_params.pkl'):
    """Rebuild a fitted UMAP mapper from pickled parameters and attributes.

    The pickle is read as a dict with optional keys:
      * ``'umap_params'``: constructor parameters passed to ``set_params``
        (``'target_backend'`` is dropped);
      * ``'umap_attributes'``: fitted attributes copied onto the mapper with
        ``setattr``; ``'embedding_'`` is wrapped in a numba typed ``List``.

    Args:
        path: pickle file to read; defaults to the mapper file used by the app.

    Returns:
        The configured ``umap.UMAP`` instance.
    """
    print("getting mapper...")

    # pickle can execute arbitrary code: only point this at trusted files.
    # The context manager closes the handle, which the bare open() leaked.
    with open(path, 'rb') as handle:
        params_new = pickle.load(handle)
    print("setting up mapper...")
    mapper = umap.UMAP()

    # Filter out 'target_backend' from umap_params if it exists.
    # NOTE(review): presumably not accepted by the installed umap version — confirm.
    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items() if k != 'target_backend'}
    mapper.set_params(**umap_params)

    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)

    # NOTE(review): why embedding_ must be a numba typed List is not visible
    # here; kept as-is to match how the mapper was serialized.
    if 'embedding_' in umap_attributes:
        mapper.embedding_ = List(umap_attributes['embedding_'])

    return mapper


    

# Load the basemap and the fitted UMAP mapper once at import time; predict()
# reads both as module-level globals.
basedata_df = setup_basemap_data()
mapper = setup_mapper()


# mount Gradio app to FastAPI app (at the root path; /static stays served by
# the StaticFiles mount)
app = gr.mount_gradio_app(app, block, path="/")

# serve the app when run as a script (e.g. `python app.py`)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)

# run the app with
# python app.py
# or
# uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload