File size: 7,570 Bytes
985ef3e
 
 
140d3a9
abf93bc
 
 
 
 
985ef3e
3a03f47
985ef3e
 
22f7397
 
 
 
 
 
7bf4167
22f7397
 
7c5f9b0
 
 
 
 
22f7397
 
 
 
 
 
 
 
 
 
7bf4167
 
 
fc74e8d
 
 
 
 
955747f
 
 
 
 
a0eb5f6
 
 
 
 
 
a2d12d3
 
 
 
a0eb5f6
 
 
 
184fac5
 
 
 
 
 
 
 
 
 
955747f
 
 
 
 
 
 
 
184fac5
 
 
 
 
 
 
 
 
 
955747f
 
 
 
 
 
 
 
 
184fac5
 
955747f
 
 
fc74e8d
 
 
a0eb5f6
 
 
 
 
 
 
 
 
 
 
4419e34
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4419e34
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
955747f
 
 
 
 
 
 
 
a0eb5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22f7397
 
 
fc74e8d
955747f
fc74e8d
955747f
05560e1
fc74e8d
955747f
 
 
05560e1
 
fc74e8d
955747f
05560e1
fc74e8d
955747f
7bf4167
 
 
22f7397
 
7bf4167
 
 
 
 
 
 
 
 
 
 
 
 
 
22f7397
955747f
 
 
 
 
611f1cc
53c4d50
955747f
 
4566fdb
 
53c4d50
 
 
611f1cc
53c4d50
 
4566fdb
 
 
 
 
 
 
 
 
 
 
 
53c4d50
 
 
4566fdb
 
955747f
53c4d50
955747f
 
22f7397
 
 
 
 
 
7bf4167
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import os
# Pin/upgrade runtime dependencies before the third-party imports below run.
# `sys.executable -m pip` targets the interpreter that is executing this file;
# a bare `pip` found on PATH may belong to a different Python installation.
# As before, a failed step does not abort startup, but it is now reported
# instead of being silently ignored.
import subprocess
import sys

for _pip_args in (
    ["uninstall", "-y", "gradio"],
    ["install", "--upgrade", "gradio"],
    ["install", "datamapplot==0.3.0"],
    ["install", "numba==0.59.1"],
    ["install", "umap-learn==0.5.6"],
    ["install", "pynndescent==0.5.12"],
):
    _pip_result = subprocess.run([sys.executable, "-m", "pip", *_pip_args], check=False)
    if _pip_result.returncode != 0:
        print(f"pip {' '.join(_pip_args)} failed with exit code {_pip_result.returncode}")



import spaces


from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
import sys



# Allow Gradio to serve files that live under static/ as static assets.
gr.set_static_paths(paths=["static/"])

# Directory that generated HTML plots are written into. It is created up front
# because Starlette's StaticFiles refuses to mount a directory that does not exist.
static_dir = Path('./static')
static_dir.mkdir(exist_ok=True, parents=True)

# FastAPI app that serves the generated files under /static; the Gradio UI is
# mounted onto this same app later in the module.
app = FastAPI()
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Gradio stuff



import datamapplot
import numpy as np
import requests
import io
import pandas as pd
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from itertools import chain
from compress_pickle import load, dump



from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
from tqdm import tqdm
from numba.typed import List
import pickle
import pynndescent
import umap














def query_records(search_term):
    """Fetch every OpenAlex work whose abstract matches *search_term*.

    Pages through ``Works().search_filter(abstract=search_term)`` 200 records
    at a time and returns the results as a DataFrame with two derived columns:

    * ``abstract`` -- plain-text abstract rebuilt from
      ``abstract_inverted_index`` (a single space when none is available);
    * ``parsed_publication`` -- ``display_name`` of the primary location's
      source, or a single space when it is missing or a placeholder name.

    When the search returns no records, an empty DataFrame with ``title``,
    ``abstract`` and ``parsed_publication`` columns is returned instead of
    raising ``KeyError``.
    """
    def invert_abstract(inv_index):
        # The abstract arrives as {word: [positions, ...]}; rebuild the text by
        # ordering every (word, position) pair on its position. Missing values
        # show up as None, or as NaN once the records are put in a DataFrame,
        # so anything that is not a dict is treated as "no abstract".
        if not isinstance(inv_index, dict):
            return ' '
        word_positions = [(word, pos) for word, positions in inv_index.items() for pos in positions]
        word_positions.sort(key=lambda pair: pair[1])
        return " ".join(word for word, _ in word_positions)

    def get_pub(location):
        try:
            source = location['source']['display_name']
        except (KeyError, TypeError):
            # `location` or its 'source' may be None/NaN (TypeError) or lack a key.
            return ' '
        if source in ('parsed_publication', 'Deleted Journal'):
            return ' '
        return source

    # Fetch records based on the search term; from_iterable flattens the pages lazily.
    query = Works().search_filter(abstract=search_term)
    records = list(chain.from_iterable(query.paginate(per_page=200)))

    if not records:
        # No hits: the raw OpenAlex columns do not exist, so return an empty frame
        # with the columns downstream code reads.
        return pd.DataFrame(columns=['title', 'abstract', 'parsed_publication'])

    records_df = pd.DataFrame(records)
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]

    return records_df




################# Setting up the model for specter2 embeddings ###################



# Hugging Face hub id of the SPECTER2 base checkpoint; the tokenizer and model share it.
_SPECTER2_BASE_ID = 'allenai/specter2_aug2023refresh_base'

# Prefer Apple MPS when present, otherwise CUDA.
# NOTE(review): there is no CPU fallback; on a machine with neither backend,
# `model.to(device)` inside create_embeddings would fail -- confirm this is intended
# (e.g. the GPU is always provided by the `spaces.GPU` decorator).
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda")
print(f"Using device: {device}")

# Both are loaded once at import time and reused by create_embeddings.
tokenizer = AutoTokenizer.from_pretrained(_SPECTER2_BASE_ID)
model = AutoAdapterModel.from_pretrained(_SPECTER2_BASE_ID)


@spaces.GPU(duration=120)
def create_embeddings(texts_to_embedd):
    """Embed texts with SPECTER2 using its "proximity" adapter.

    Loads the ``allenai/specter2_aug2023refresh`` adapter onto the module-level
    ``model`` under the name "proximity", activates it, moves the model to the
    module-level ``device`` and encodes the texts in batches of 32.

    Args:
        texts_to_embedd: Sequence of input strings (``predict`` passes
            title / publication / abstract joined with the tokenizer's sep token).

    Returns:
        NumPy array with one row per input text: the first-token ([CLS]) vector
        of the model's last hidden state. Empty input yields an array of shape
        ``(0, model.config.hidden_size)`` instead of raising.
    """
    print(len(texts_to_embedd))

    if len(texts_to_embedd) == 0:
        # torch.cat() on an empty list raises, so short-circuit before touching the GPU.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    # Load the proximity adapter and activate it.
    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")
    model.to(device)

    def batch_generator(data, batch_size):
        """Yield consecutive slices of *data* of length *batch_size* (the last may be shorter)."""
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    def empty_device_cache(dev):
        """Release cached allocator memory for *dev*'s backend; no-op for other backends."""
        # torch.mps.empty_cache() does nothing for CUDA memory, so dispatch on the device type.
        if dev.type == "cuda":
            torch.cuda.empty_cache()
        elif dev.type == "mps":
            torch.mps.empty_cache()

    def encode_texts(texts, dev, batch_size=16):
        """Run *texts* through the model in batches and stack the [CLS] vectors on the CPU."""
        model.eval()
        batch_embeddings = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(batch_generator(texts, batch_size))):
                inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(dev)
                outputs = model(**inputs)
                # Take the [CLS] token representation and move it to the CPU so
                # accumulated results do not occupy accelerator memory.
                batch_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())
                # Hand cached blocks back every 100 batches.
                if batch_idx and batch_idx % 100 == 0:
                    empty_device_cache(dev)
        return torch.cat(batch_embeddings, dim=0)

    return encode_texts(texts_to_embedd, device, batch_size=32).cpu().numpy()





def predict(text_input, progress=gr.Progress()):
    """Search OpenAlex for *text_input*, embed the hits, and render the base map.

    Steps: fetch matching works with ``query_records``, embed
    ``title [SEP] publication [SEP] abstract`` with ``create_embeddings``, then
    write an interactive datamapplot of the precomputed ``basedata_df`` into
    ``static_dir``.

    Args:
        text_input: Search term matched against OpenAlex abstracts.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(link, iframe)``: an ``<a>`` tag and an ``<iframe>`` tag, both pointing
        at the generated HTML file under ``/static/``.
    """
    import time
    import uuid

    def as_text(value):
        # title / parsed_publication can be None (or NaN inside the DataFrame);
        # string concatenation with `+` would raise TypeError on those.
        return value if isinstance(value, str) else ''

    # get data.
    records_df = query_records(text_input)
    print(records_df)

    texts_to_embedd = [
        as_text(title) + tokenizer.sep_token + as_text(publication) + tokenizer.sep_token + as_text(abstract)
        for title, publication, abstract in zip(
            records_df['title'], records_df['parsed_publication'], records_df['abstract']
        )
    ]

    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)
    # TODO(review): `embeddings` is not used below -- the plot only shows the
    # precomputed base map. Projecting them with the module-level `mapper`
    # appears to be the intended next step; confirm.

    # Epoch seconds plus a random suffix: portable (strftime('%s') is a glibc
    # extension and, applied to a naive utcnow(), is shifted by the local UTC
    # offset) and unique for concurrent requests within the same second.
    file_name = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Loading hover data...")

    # Column-wise zip instead of iterrows(): same strings, without building a
    # Series per row of the large base map.
    hover_text = [
        str(ix) + ', ' + str(publication) + str(title)
        for ix, publication, title in zip(
            basedata_df.index, basedata_df['parsed_publication'], basedata_df['title']
        )
    ]

    plot = datamapplot.create_interactive_plot(
        basedata_df[['x', 'y']].values,
        np.array(basedata_df['cluster_1_labels']),
        hover_text=hover_text,
        font_family="Roboto Condensed",
    )

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)

    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe

# Intro text rendered at the top of the page.
# NOTE(review): this is generic Gradio+FastAPI template copy rather than a
# description of the OpenAlex search/map app.
_INTRO_MARKDOWN = """
## Gradio + FastAPI + Static Server
This is a demo of how to use Gradio with FastAPI and a static server.
The Gradio app generates dynamic HTML files and stores them in a static directory. FastAPI serves the static files.
"""

# Two-column layout: search input and result link on the left, embedded plot
# preview on the right. Clicking "New" runs `predict`, whose (link, iframe)
# result fills the Markdown and HTML components respectively.
with gr.Blocks() as block:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(label="Name")
            markdown = gr.Markdown(label="Output Box")
            new_btn = gr.Button("New")
        with gr.Column():
            html = gr.HTML(show_label=True, label="HTML preview")

    new_btn.click(
        fn=predict,
        inputs=[text_input],
        outputs=[markdown, html],
    )



def setup_basemap_data():
    """Load the precomputed base-map DataFrame from disk and echo it to stdout.

    Reads ``100k_filtered_OA_sample_cluster_and_positions.bz`` via
    compress_pickle's ``load``. ``predict`` reads its ``x``, ``y``,
    ``cluster_1_labels``, ``parsed_publication`` and ``title`` columns.
    """
    # NOTE: compress_pickle unpickles the file, so it must come from a trusted source.
    print("getting basemap data...")
    frame = load("100k_filtered_OA_sample_cluster_and_positions.bz")
    print(frame)
    return frame


    
def setup_mapper():
    """Rebuild a fitted UMAP reducer from its pickled parameters and attributes.

    Reads ``umap_mapper_300k_random_OA_specter_2_params.pkl``, a dict with
    optional ``'umap_params'`` and ``'umap_attributes'`` entries. The params
    (minus ``target_backend``) are applied to a fresh ``umap.UMAP`` via
    ``set_params``, and every attribute is then copied onto the instance.

    Returns:
        The reconstructed ``umap.UMAP`` instance.
    """
    print("getting mapper...")

    # NOTE: pickle can execute arbitrary code on load; only use a trusted file.
    # The `with` block closes the handle (the original left it open).
    with open('umap_mapper_300k_random_OA_specter_2_params.pkl', 'rb') as fh:
        params_new = pickle.load(fh)
    print("setting up mapper...")
    mapper = umap.UMAP()

    # Filter out 'target_backend' from umap_params if it exists
    # (presumably not accepted by this umap version's set_params -- confirm).
    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items() if k != 'target_backend'}
    mapper.set_params(**umap_params)

    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)

    if 'embedding_' in umap_attributes:
        # NOTE(review): the stored embedding is wrapped in a numba typed List
        # instead of being assigned as-is; confirm that downstream use
        # (e.g. mapper.transform) expects this type.
        mapper.embedding_ = List(umap_attributes['embedding_'])

    return mapper


    

# Load shared state once at import time; `predict` reads `basedata_df` as a global.
basedata_df = setup_basemap_data()
mapper = setup_mapper()

# Mount the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, block, path="/")

if __name__ == "__main__":
    # Only when executed as a script: serve on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")

# run the app with
# python app.py
# or
# uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload