# Author: m7n
# Commit message: "Update app.py"
# Commit: c2c5be0 (verified)
import os
import time
# Timestamped prints at each start-up stage make slow Space start-ups easier to diagnose from the logs.
print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")
# Former ad-hoc dependency pins, kept for reference only (all disabled).
# NOTE(review): dependencies are presumably pinned elsewhere now (e.g. requirements.txt) — confirm and delete.
#os.system("pip uninstall -y gradio")
#os.system("pip install --upgrade gradio")
#os.system("pip install datamapplot==0.3.0")
#os.system("pip install numba==0.59.1")
#os.system("pip install umap-learn==0.5.6")
#os.system("pip install pynndescent==0.5.12")
# `spaces` provides the @spaces.GPU decorator used by create_embeddings.
# NOTE(review): on Hugging Face ZeroGPU this import is presumably required to come before
# torch/CUDA-using packages — keep it near the top; verify before reordering.
import spaces
from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
import sys
# Allow Gradio to serve files placed under static/ (the generated plot HTML files are saved there).
gr.set_static_paths(paths=["static/"])
# create a FastAPI app; the Gradio UI is mounted onto it at the end of the file
app = FastAPI()
# create a static directory to store the generated plot files
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)
# mount FastAPI StaticFiles server so /static/<file> URLs resolve to files in static_dir
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Plotting, data handling, OpenAlex client, embedding model and UMAP projection imports
import datamapplot
import numpy as np
import requests
import io
import pandas as pd
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from itertools import chain
# compress_pickle's load/dump are currently referenced only from commented-out code.
from compress_pickle import load, dump
from urllib.parse import urlparse, parse_qs
import re
import pyalex
# Contact email that pyalex sends with OpenAlex API requests (OpenAlex "polite pool").
pyalex.config.email = "maximilian.noichl@uni-bamberg.de"
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
from tqdm import tqdm
# numba typed List is used by setup_mapper to wrap the restored UMAP embedding_.
from numba.typed import List
import pickle
import pynndescent
import umap
print(f"Imports are done: {time.strftime('%Y-%m-%d %H:%M:%S')}")
def openalex_url_to_pyalex_query(url):
    """
    Convert an OpenAlex search URL to a pyalex query.

    Recognised URL parameters:

    * ``filter`` -- comma-separated ``key:value`` pairs. ``default.search``
      becomes a full-text ``search``; every other key is passed to
      ``Works.filter``. Entries without a ``:`` are skipped.
    * ``search`` -- top-level full-text search as used by the OpenAlex API
      (``https://api.openalex.org/works?search=...``).
    * ``sort`` -- comma-separated sort keys, either in OpenAlex syntax
      (``field:desc`` / ``field:asc``) or as ``-field`` (descending) /
      ``field`` (ascending). Empty entries (e.g. a trailing comma) are ignored.

    Args:
        url (str): The OpenAlex search URL.

    Returns:
        tuple: (Works object, dict of parameters). The dict holds the raw string
        values of whichever of ``page``, ``per-page``, ``sample`` and ``seed``
        appear in the URL.
    """
    parsed_url = urlparse(url)
    query_params = parse_qs(parsed_url.query)

    query = Works()

    # Handle filters
    if 'filter' in query_params:
        for f in query_params['filter'][0].split(','):
            if ':' not in f:
                continue
            key, value = f.split(':', 1)
            if key == 'default.search':
                query = query.search(value)
            else:
                query = query.filter(**{key: value})

    # Handle API-style top-level full-text search (previously ignored).
    if 'search' in query_params:
        query = query.search(query_params['search'][0])

    # Handle sort
    if 'sort' in query_params:
        for s in query_params['sort'][0].split(','):
            if not s:
                continue
            field, sep, direction = s.rpartition(':')
            if sep and field and direction in ('asc', 'desc'):
                # OpenAlex syntax: "cited_by_count:desc"
                query = query.sort(**{field: direction})
            elif s.startswith('-'):
                query = query.sort(**{s[1:]: 'desc'})
            else:
                query = query.sort(**{s: 'asc'})

    # Handle other parameters
    params = {}
    for key in ['page', 'per-page', 'sample', 'seed']:
        if key in query_params:
            params[key] = query_params[key][0]
    return query, params
def invert_abstract(inv_index):
    """Rebuild a plain-text abstract from an OpenAlex inverted index.

    ``inv_index`` maps each word to the list of positions at which it occurs.
    The words are emitted in ascending position order and joined by single
    spaces. When ``inv_index`` is None, a single space ``' '`` is returned.
    """
    if inv_index is None:
        return ' '
    placed = []
    for word, positions in inv_index.items():
        placed.extend((word, position) for position in positions)
    # Stable sort on position only, so ties keep their dict-iteration order.
    placed.sort(key=lambda pair: pair[1])
    return " ".join(word for word, _ in placed)
def get_pub(x):
    """Return the display name of a location's source, or a single space.

    ``x`` is expected to look like ``{'source': {'display_name': ...}}`` (an
    OpenAlex location such as ``primary_location``). A single space ``' '`` is
    returned when ``x`` or its ``source`` is missing, null or not subscriptable
    by key, and when the name is one of the placeholder values excluded below.
    A ``display_name`` of None is returned unchanged.
    """
    excluded = ('parsed_publication', 'Deleted Journal')
    try:
        source = x['source']['display_name']
    except (KeyError, TypeError):
        # KeyError: a key is absent; TypeError: x or x['source'] is None/NaN/not a mapping.
        return ' '
    return source if source not in excluded else ' '
#def query_records(search_term):
# # Fetch records based on the search term in the abstract!
# query = Works().search([search_term])
# query_length = Works().search([search_term]).count()
# records = []
# #total_pages = (query_length + 199) // 200 # Calculate total number of pages
# progress=gr.Progress()
# for i, record in progress.tqdm(enumerate(chain(*query.paginate(per_page=200)))):
# records.append(record)
# # Calculate progress from 0 to 0.1
# #achieved_progress = min(0.1, (i + 1) / query_length * 0.1)
# # Update progress bar
# #progress(achieved_progress, desc="Getting queried data...")
# records_df = pd.DataFrame(records)
# records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
# records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
# records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
# records_df['abstract'] = records_df['abstract'].fillna(' ')
# records_df['title'] = records_df['title'].fillna(' ')
# return records_df
################# Setting up the model for specter2 embeddings ###################
# Module-level globals `device`, `tokenizer` and `model` are used by create_embeddings/predict.
print(f"Setting up language model: {time.strftime('%Y-%m-%d %H:%M:%S')}")
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
_SPECTER2_BASE_CHECKPOINT = 'allenai/specter2_aug2023refresh_base'
tokenizer = AutoTokenizer.from_pretrained(_SPECTER2_BASE_CHECKPOINT)
model = AutoAdapterModel.from_pretrained(_SPECTER2_BASE_CHECKPOINT)
@spaces.GPU(duration=60)
def create_embeddings(texts_to_embedd):
    """Embed texts with the SPECTER2 base model and its "proximity" adapter.

    Args:
        texts_to_embedd (list[str]): Texts to embed. predict() passes strings
            of the form ``title + sep + publication + sep + abstract``.

    Returns:
        numpy.ndarray: One row per input text, namely the first-token ([CLS])
        vector of the model's last hidden state. An empty input yields an array
        with zero rows and ``model.config.hidden_size`` columns (previously
        ``torch.cat`` raised on the empty list).
    """
    print(len(texts_to_embedd))

    # Load the proximity adapter and activate it.
    # NOTE(review): this reloads the adapter on every call; loading it once would be cheaper
    # if that is compatible with the @spaces.GPU execution model — verify before changing.
    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")
    model.to(device)

    def batch_generator(data, batch_size):
        """Yield consecutive slices of ``data`` with at most ``batch_size`` items each."""
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    def encode_texts(texts, device, batch_size=16):
        """Encode ``texts`` batch by batch and return the stacked embeddings as a CPU tensor."""
        model.eval()
        all_embeddings = []
        with torch.no_grad():
            for batch_index, batch in enumerate(tqdm(batch_generator(texts, batch_size))):
                inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
                outputs = model(**inputs)
                embeddings = outputs.last_hidden_state[:, 0, :]  # first ([CLS]) token of each sequence
                all_embeddings.append(embeddings.cpu())  # accumulate results on the CPU, not the GPU
                # Release cached GPU memory after every 100th batch (batches 100, 200, ...).
                if batch_index and batch_index % 100 == 0:
                    torch.cuda.empty_cache()
        if not all_embeddings:
            return torch.empty((0, model.config.hidden_size))
        return torch.cat(all_embeddings, dim=0)

    # Encode in batches of 32 and hand back a NumPy array.
    embeddings = encode_texts(texts_to_embedd, device, batch_size=32).cpu().numpy()
    return embeddings
# Log that the module-level tokenizer/model setup has finished.
print(f"Language model is set up: {time.strftime('%Y-%m-%d %H:%M:%S')}")
# CSS injected into the datamapplot HTML page (real /* */ comments instead of invalid `//` lines).
_PLOT_CUSTOM_CSS = """
#title-container {
    background: #edededaa;
    border-radius: 2px;
    box-shadow: 2px 3px 10px #aaaaaa00;
}
#search-container {
    position: fixed !important;
    top: 20px !important;
    right: 20px !important;
    left: auto !important;
    width: 200px !important;
    z-index: 9999 !important;
}
#search {
    /* padding: 8px 8px !important; */
    /* border: none !important; */
    /* border-radius: 20px !important; */
    background-color: #ffffffaa !important;
    font-family: 'Roboto Condensed', sans-serif !important;
    font-size: 14px;
    /* box-shadow: 0 0px 0px #aaaaaa00 !important; */
}
"""


def _fetch_openalex_records(query, query_length, max_records, progress):
    """Download works for ``query``, reporting progress in the 0-0.3 range.

    Pages are consumed lazily (``chain.from_iterable``), so the progress bar
    advances while downloading and the download stops after ``max_records``
    works. A ``max_records`` of None means no limit.
    """
    from itertools import islice

    stream = chain.from_iterable(query.paginate(per_page=200))
    expected = query_length
    if max_records is not None:
        stream = islice(stream, max_records)
        expected = min(query_length, max_records)
    expected = max(expected, 1)  # avoid division by zero for empty/odd counts

    records = []
    for record in stream:
        records.append(record)
        progress(min(0.3, 0.3 * len(records) / expected), desc="Getting queried data...")
    return records


def _records_to_frame(records):
    """Build the works DataFrame with plain-text ``abstract``, ``parsed_publication`` and ``title`` columns."""
    records_df = pd.DataFrame(records)
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
    records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
    records_df['abstract'] = records_df['abstract'].fillna(' ')
    records_df['title'] = records_df['title'].fillna(' ')
    return records_df


def _new_plot_file_name():
    """Return a unique ``<unix-seconds>_<random hex>.html`` file name for a saved plot.

    The random suffix keeps concurrent requests within the same second from
    overwriting each other's output.
    """
    import uuid

    return f"{int(time.time())}_{uuid.uuid4().hex[:8]}.html"


def predict(text_input, sample_size_slider, reduce_sample_checkbox,sample_reduction_method, progress=gr.Progress()):
    """Project the works of an OpenAlex search onto the base map and save an interactive plot.

    Args:
        text_input (str): OpenAlex search URL (parsed by openalex_url_to_pyalex_query).
        sample_size_slider (int | float): Maximum number of works to keep when
            ``reduce_sample_checkbox`` is set.
        reduce_sample_checkbox (bool): Whether to cap the number of works.
        sample_reduction_method (str | None): "Random" samples from all results.
            Any other value ("Order of Results", or None when nothing is selected)
            keeps the first results and stops downloading once enough are fetched.
        progress (gr.Progress): Progress tracker injected by Gradio.

    Returns:
        tuple[str, str]: An HTML link to the saved plot and an iframe embedding it.

    Raises:
        gr.Error: If the query returns no works.
    """
    print('getting data to project')
    progress(0, desc="Starting...")
    query, _params = openalex_url_to_pyalex_query(text_input)
    query_length = query.count()
    print(f'Requesting {query_length} entries...')

    sample_size = max(0, int(sample_size_slider)) if reduce_sample_checkbox else None
    # Only random sampling needs the complete result set; first-N selection can stop early.
    if sample_size is not None and sample_reduction_method != "Random":
        max_records = sample_size
    else:
        max_records = None
    records = _fetch_openalex_records(query, query_length, max_records, progress)
    if not records:
        raise gr.Error("The OpenAlex query returned no results.")
    records_df = _records_to_frame(records)

    if sample_size is not None:
        sample_size = min(sample_size, len(records_df))
        if sample_reduction_method == "Random":
            records_df = records_df.sample(sample_size)
        else:
            records_df = records_df.iloc[:sample_size].copy()
    print(records_df)

    progress(0.3, desc="Embedding Data...")
    sep = tokenizer.sep_token
    texts_to_embedd = [
        title + sep + publication + sep + abstract
        for title, publication, abstract in zip(records_df['title'], records_df['parsed_publication'], records_df['abstract'])
    ]
    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)

    progress(0.5, desc="Project into UMAP-embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x', 'y']] = umap_embeddings
    basedata_df['color'] = '#ced4d211'
    records_df['color'] = '#f98e31'

    progress(0.6, desc="Set up data...")
    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df = stacked_df.reset_index(drop=True)
    print(stacked_df)
    extra_data = pd.DataFrame(stacked_df['doi'])

    file_name = _new_plot_file_name()
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Plotting...")
    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_1_labels']), np.array(stacked_df['cluster_2_labels']), np.array(stacked_df['cluster_3_labels']),
        hover_text=[str(title) for title in stacked_df['title']],
        marker_color_array=stacked_df['color'],
        use_medoids=True,
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        color_label_text=False,
        font_family="Roboto Condensed",
        font_weight=700,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=_PLOT_CUSTOM_CSS,
        initial_zoom_fraction=.8,
        enable_search=True)

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)
    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
################ MAIN BLOCK #####################
# with gr.Blocks() as block:
# gr.Markdown("""
# ## Mapping OpenAlex-Queries
# This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them.
# Your a philosopher of science who wonders where the concept of a fitnesslandscape has appeared...
# """)
# with gr.Row():
# with gr.Column():
# text_input = gr.Textbox(label="Name")
# markdown = gr.Markdown(label="Output Box")
# new_btn = gr.Button("New")
# with gr.Column():
# html = gr.HTML(label="HTML preview", show_label=True)
# new_btn.click(fn=predict, inputs=[text_input], outputs=[markdown, html])
with gr.Blocks() as block:
    gr.Markdown("""
## Mapping OpenAlex-Queries
Enter the URL to an OpenAlex-search below. It will take a few minutes, but then the result will be projected onto a map of the OA database as a whole.
""")
    # This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them.
    # You're a philosopher of science who wonders where the concept of a fitness landscape has appeared...
    with gr.Column():
        text_input = gr.Textbox(label="OpenAlex-search URL")
        with gr.Row():
            reduce_sample_checkbox = gr.Checkbox(label="Reduce Sample Size", value=True, info="Reduce sample size.")
            sample_size_slider = gr.Slider(label="Sample Size", minimum=10, maximum=20000, step=10, value=1000, info="How many samples to keep.")
            # Explicit default so the reduction method is never unset; the label names the
            # setting instead of repeating one of its choices.
            sample_reduction_method = gr.Dropdown(
                ["Order of Results", "Random"],
                value="Order of Results",
                label="Sample Reduction Method",
                info="How to choose the samples to keep.",
            )
        new_btn = gr.Button("Run Query", variant='primary')
    # Outputs of predict(): a link to the saved plot and an iframe preview of it.
    markdown = gr.Markdown(label="")
    html = gr.HTML(label="HTML preview", show_label=True)
    new_btn.click(fn=predict, inputs=[text_input, sample_size_slider, reduce_sample_checkbox, sample_reduction_method], outputs=[markdown, html])
def setup_basemap_data(path='100k_filtered_OA_sample_cluster_and_positions.pkl'):
    """Load the precomputed base-map data from a pickle file.

    Args:
        path: Pickle file to read. Defaults to the bundled
            ``100k_filtered_OA_sample_cluster_and_positions.pkl``.

    Returns:
        The unpickled object. predict() uses it as a DataFrame providing
        ``x``/``y``, ``cluster_{1,2,3}_labels``, ``title`` and ``doi`` columns.
    """
    print(f"Getting basemap data: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    # pickle can execute arbitrary code on load: only use trusted, locally shipped files.
    with open(path, 'rb') as f:
        basedata_df = pickle.load(f)
    print(basedata_df)
    return basedata_df
def setup_mapper(path='umap_mapper_300k_random_OA_specter_2_params.pkl'):
    """Rebuild a fitted UMAP mapper from a pickled parameter/attribute dump.

    The pickle is read as a dict with optional keys ``'umap_params'``
    (passed to ``UMAP.set_params``, minus ``target_backend``) and
    ``'umap_attributes'`` (set directly on the estimator with ``setattr``).

    Args:
        path: Pickle file to read. Defaults to the bundled
            ``umap_mapper_300k_random_OA_specter_2_params.pkl``.

    Returns:
        umap.UMAP: The reconstructed mapper (predict() calls its ``transform``).
    """
    print(f"Getting Mapper: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    # pickle can execute arbitrary code on load: only use trusted, locally shipped files.
    with open(path, 'rb') as f:
        params_new = pickle.load(f)
    print("setting up mapper...")
    mapper = umap.UMAP()
    # Filter out 'target_backend' from umap_params if it exists
    # NOTE(review): presumably not accepted by the installed umap version — confirm.
    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items() if k != 'target_backend'}
    mapper.set_params(**umap_params)
    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)
    if 'embedding_' in umap_attributes:
        # NOTE(review): embedding_ is wrapped in a numba typed List; the reason is not visible here — confirm.
        mapper.embedding_ = List(umap_attributes['embedding_'])
    return mapper
# Load the base map and the fitted UMAP mapper once at start-up; predict() reads both as module globals.
basedata_df = setup_basemap_data()
mapper = setup_mapper()
print(f"Setup done, starting up app: {time.strftime('%Y-%m-%d %H:%M:%S')}")
# mount Gradio app to FastAPI app at the root path (/static stays served by the StaticFiles mount)
app = gr.mount_gradio_app(app, block, path="/")
# serve the app when this file is executed as a script
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
# run the app with
# python app.py
# or
# uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload