Spaces:
Build error
Build error
import os | |
#os.system("pip uninstall -y gradio") | |
#os.system("pip install --upgrade gradio") | |
#os.system("pip install datamapplot==0.3.0") | |
#os.system("pip install numba==0.59.1") | |
#os.system("pip install umap-learn==0.5.6") | |
#os.system("pip install pynndescent==0.5.12") | |
import spaces | |
from pathlib import Path | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
import uvicorn | |
import gradio as gr | |
from datetime import datetime | |
import sys | |
# Register the local "static/" directory with Gradio as a static-file path.
gr.set_static_paths(paths=["static/"])
# create a FastAPI app
app = FastAPI()
# create a static directory to store the static files
# (predict() writes its generated plot HTML files into this directory)
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)
# mount FastAPI StaticFiles server so /static/<file> URLs resolve;
# the link/iframe HTML returned by predict() points at these URLs
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Gradio stuff | |
import datamapplot | |
import numpy as np | |
import requests | |
import io | |
import pandas as pd | |
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders | |
from itertools import chain | |
from compress_pickle import load, dump | |
from urllib.parse import urlparse, parse_qs | |
import re | |
from transformers import AutoTokenizer | |
from adapters import AutoAdapterModel | |
import torch | |
from tqdm import tqdm | |
from numba.typed import List | |
import pickle | |
import pynndescent | |
import umap | |
def openalex_url_to_pyalex_query(url):
    """
    Convert an OpenAlex search URL to a pyalex query.

    Handles the ``search``, ``filter`` and ``sort`` query parameters:

    * ``search=...`` is applied as a full-text search.
    * ``filter=`` holds comma-separated ``key:value`` pairs; ``default.search``
      is mapped to a full-text search, every other key to a pyalex filter.
    * ``sort=`` holds comma-separated sort keys, accepted as ``field:desc`` /
      ``field:asc`` (the form OpenAlex uses), ``-field`` (descending) or a
      bare ``field`` (ascending).

    Args:
        url (str): The OpenAlex search URL.

    Returns:
        tuple: (Works object, dict of parameters). The dict holds whichever of
        'page', 'per-page', 'sample' and 'seed' appear in the URL, as strings;
        they are returned for the caller and NOT applied to the query.
    """
    parsed_url = urlparse(url)
    query_params = parse_qs(parsed_url.query)

    query = Works()

    # API-style URLs carry the full-text search as a top-level parameter;
    # without this, such a URL would turn into an unfiltered query.
    if 'search' in query_params:
        query = query.search(query_params['search'][0])

    if 'filter' in query_params:
        for f in query_params['filter'][0].split(','):
            if ':' not in f:
                continue
            key, value = f.split(':', 1)
            if key == 'default.search':
                query = query.search(value)
            else:
                query = query.filter(**{key: value})

    if 'sort' in query_params:
        for spec in query_params['sort'][0].split(','):
            spec = spec.strip()
            if not spec:
                continue
            if ':' in spec:
                # OpenAlex form, e.g. "cited_by_count:desc".
                field, direction = spec.split(':', 1)
            elif spec.startswith('-'):
                field, direction = spec[1:], 'desc'
            else:
                field, direction = spec, 'asc'
            query = query.sort(**{field: direction})

    params = {
        key: query_params[key][0]
        for key in ('page', 'per-page', 'sample', 'seed')
        if key in query_params
    }
    return query, params
def invert_abstract(inv_index):
    """Rebuild plain text from an OpenAlex abstract inverted index.

    Args:
        inv_index: mapping of word -> iterable of positions, or None.

    Returns:
        The words joined by single spaces in ascending position order
        (stable for equal positions), or ' ' when ``inv_index`` is None.
    """
    if inv_index is None:
        return ' '
    placed = []
    for word, positions in inv_index.items():
        for position in positions:
            placed.append((position, word))
    # Sort on the position only, so ties keep their insertion order.
    placed.sort(key=lambda pair: pair[0])
    return " ".join(word for _, word in placed)
def get_pub(x):
    """Return the venue display name from a work's ``primary_location``.

    Args:
        x: expected to be a dict with a ``'source'`` dict holding
           ``'display_name'``; may also be None/NaN or lack those keys.

    Returns:
        The source's ``display_name`` (None if that value itself is None).
        Returns ' ' for the names 'parsed_publication' and 'Deleted Journal'.
        Also returns ' ' when the location or source is missing or not
        subscriptable.
    """
    try:
        source = x['source']['display_name']
    except (KeyError, TypeError):
        # Missing key, or x / x['source'] is None/NaN: no venue available.
        return ' '
    # NOTE(review): 'parsed_publication' is also the name of the column this
    # result is stored in by the caller; unclear why it is excluded.
    if source in ('parsed_publication', 'Deleted Journal'):
        return ' '
    return source
#def query_records(search_term): | |
# # Fetch records based on the search term in the abstract! | |
# query = Works().search([search_term]) | |
# query_length = Works().search([search_term]).count() | |
# records = [] | |
# #total_pages = (query_length + 199) // 200 # Calculate total number of pages | |
# progress=gr.Progress() | |
# for i, record in progress.tqdm(enumerate(chain(*query.paginate(per_page=200)))): | |
# records.append(record) | |
# # Calculate progress from 0 to 0.1 | |
# #achieved_progress = min(0.1, (i + 1) / query_length * 0.1) | |
# # Update progress bar | |
# #progress(achieved_progress, desc="Getting queried data...") | |
# records_df = pd.DataFrame(records) | |
# records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']] | |
# records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']] | |
# records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ') | |
# records_df['abstract'] = records_df['abstract'].fillna(' ') | |
# records_df['title'] = records_df['title'].fillna(' ') | |
# return records_df | |
################# Setting up the model for specter2 embeddings ###################
#device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
#print(f"Using device: {device}")
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Tokenizer and base model are loaded once at import time and shared as module
# globals; create_embeddings() attaches the proximity adapter to this model and
# predict() uses tokenizer.sep_token to join title/venue/abstract.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_aug2023refresh_base')
# Set once the SPECTER2 proximity adapter has been attached to the shared
# ``model``, so later requests do not re-load the adapter weights.
_proximity_adapter_loaded = False


def create_embeddings(texts_to_embedd, batch_size=32):
    """Embed texts with SPECTER2 (base model + proximity adapter).

    Args:
        texts_to_embedd: list of strings to embed (sliced per batch, so it
            must support slicing).
        batch_size: number of texts tokenized and run per forward pass.

    Returns:
        numpy.ndarray of shape (len(texts_to_embedd), hidden_size): the
        first-token ([CLS]) hidden state for each text.
    """
    global _proximity_adapter_loaded
    print(len(texts_to_embedd))

    if len(texts_to_embedd) == 0:
        # torch.cat() rejects an empty list; return an empty 2-D array instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    if not _proximity_adapter_loaded:
        model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
        model.set_active_adapters("proximity")
        # Move after loading so the freshly added adapter weights are moved too.
        model.to(device)
        _proximity_adapter_loaded = True

    model.eval()
    chunks = []
    with torch.no_grad():
        starts = range(0, len(texts_to_embedd), batch_size)
        for batch_idx, start in enumerate(tqdm(starts)):
            batch = texts_to_embedd[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
            outputs = model(**inputs)
            # Take the [CLS] token representation; keep results on CPU to free GPU memory.
            chunks.append(outputs.last_hidden_state[:, 0, :].cpu())
            # Release cached CUDA memory every 100 batches on long runs.
            if batch_idx and batch_idx % 100 == 0:
                torch.cuda.empty_cache()
    return torch.cat(chunks, dim=0).numpy()
# Styling injected into the generated datamapplot HTML page.
_PLOT_CUSTOM_CSS = """
#title-container {
background: #edededaa;
border-radius: 2px;
box-shadow: 2px 3px 10px #aaaaaa00;
}
#search-container {
position: fixed !important;
top: 20px !important;
right: 20px !important;
left: auto !important;
width: 200px !important;
z-index: 9999 !important;
}
#search {
// padding: 8px 8px !important;
// border: none !important;
// border-radius: 20px !important;
background-color: #ffffffaa !important;
font-family: 'Roboto Condensed', sans-serif !important;
font-size: 14px;
// box-shadow: 0 0px 0px #aaaaaa00 !important;
}
"""


def _fetch_records(text_input, progress):
    """Run the OpenAlex query described by the URL ``text_input``; return raw work records.

    Progress is reported within the 0.0-0.3 part of the overall run.
    """
    query, _params = openalex_url_to_pyalex_query(text_input)
    query_length = query.count()
    print(f'Requesting {query_length} entries...')
    records = []
    for i, record in enumerate(chain(*query.paginate(per_page=200))):
        records.append(record)
        # Scale by records fetched (not pages) and clamp, so the bar stays
        # within its 0.3 budget even if count() and pagination disagree.
        progress(0.3 * min(1.0, (i + 1) / max(query_length, 1)), desc="Getting queried data...")
    return records


def _records_to_df(records):
    """Build a DataFrame from work records with abstract/venue/title text filled in."""
    records_df = pd.DataFrame(records)
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
    records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
    records_df['abstract'] = records_df['abstract'].fillna(' ')
    records_df['title'] = records_df['title'].fillna(' ')
    return records_df


def predict(text_input, sample_size_slider, reduce_sample_checkbox, progress=gr.Progress()):
    """Fetch an OpenAlex query, embed it and plot it on top of the basemap.

    Args:
        text_input: OpenAlex search URL.
        sample_size_slider: number of records to keep when subsampling.
        reduce_sample_checkbox: whether to subsample the fetched records.
        progress: Gradio progress tracker.

    Returns:
        (link_html, iframe_html) pointing at the saved plot under /static/.

    Raises:
        gr.Error: if the query returns no records.
    """
    from datetime import timezone
    import uuid

    print('getting data to project')
    progress(0, desc="Starting...")
    records = _fetch_records(text_input, progress)
    if not records:
        raise gr.Error("The OpenAlex query returned no records.")
    records_df = _records_to_df(records)
    if reduce_sample_checkbox:
        # The slider may deliver a float, and pandas refuses n > len(df)
        # when sampling without replacement.
        sample_size = min(int(sample_size_slider), len(records_df))
        records_df = records_df.sample(sample_size)
    print(records_df)

    progress(0.3, desc="Embedding Data...")
    sep = tokenizer.sep_token
    texts_to_embedd = [
        title + sep + publication + sep + abstract
        for title, publication, abstract in zip(
            records_df['title'], records_df['parsed_publication'], records_df['abstract']
        )
    ]
    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)

    progress(0.5, desc="Project into UMAP-embedding...")
    umap_embeddings = mapper.transform(embeddings)
    # Assumes the mapper projects to exactly 2 components.
    records_df[['x', 'y']] = umap_embeddings
    records_df['color'] = '#f98e31'

    progress(0.6, desc="Set up data...")
    # Colour a copy so the shared module-level basemap is not mutated per request.
    base_df = basedata_df.assign(color='#ced4d211')
    stacked_df = pd.concat([base_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df = stacked_df.reset_index(drop=True)
    print(stacked_df)
    extra_data = pd.DataFrame(stacked_df['doi'])

    # Aware UTC epoch seconds: strftime('%s') is a non-portable glibc extension
    # that treats naive datetimes as local time. The random suffix keeps
    # requests within the same second from overwriting each other's plot.
    timestamp = int(datetime.now(timezone.utc).timestamp())
    file_name = f"{timestamp}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Plotting...")
    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_1_labels']),
        np.array(stacked_df['cluster_2_labels']),
        np.array(stacked_df['cluster_3_labels']),
        hover_text=[str(title) for title in stacked_df['title']],
        marker_color_array=stacked_df['color'],
        use_medoids=True,
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        color_label_text=False,
        font_family="Roboto Condensed",
        font_weight=700,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=_PLOT_CUSTOM_CSS,
        initial_zoom_fraction=.8,
        enable_search=True,
    )

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)
    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
################ MAIN BLOCK ##################### | |
# with gr.Blocks() as block: | |
# gr.Markdown(""" | |
# ## Mapping OpenAlex-Queries | |
# This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them. | |
# Your a philosopher of science who wonders where the concept of a fitnesslandscape has appeared... | |
# """) | |
# with gr.Row(): | |
# with gr.Column(): | |
# text_input = gr.Textbox(label="Name") | |
# markdown = gr.Markdown(label="Output Box") | |
# new_btn = gr.Button("New") | |
# with gr.Column(): | |
# html = gr.HTML(label="HTML preview", show_label=True) | |
# new_btn.click(fn=predict, inputs=[text_input], outputs=[markdown, html]) | |
# Gradio UI: one column with the query inputs; clicking "Run Query" calls
# predict() and shows the returned link (markdown) and iframe (html).
with gr.Blocks() as block:
    gr.Markdown("""
## Mapping OpenAlex-Queries
Enter the URL to an OpenAlex-search below. It will take a few minutes, but then the result will be projected onto a map of the OA database as a whole.
""")
    # Draft intro text (not shown in the UI):
    # This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them.
    # You're a philosopher of science who wonders where the concept of a fitness landscape has appeared...
    with gr.Column():
        text_input = gr.Textbox(label="OpenAlex Fulltext-Search")
        # Only applied when the checkbox below is ticked.
        sample_size_slider = gr.Slider(label="Sample Size", minimum=10, maximum=20000, step=10, value=1000)
        reduce_sample_checkbox = gr.Checkbox(label="Reduce Sample Size", value=True)
        new_btn = gr.Button("Run Query")
        markdown = gr.Markdown(label="")
        html = gr.HTML(label="HTML preview", show_label=True)
    # Input order must match predict()'s positional parameters.
    new_btn.click(fn=predict, inputs=[text_input, sample_size_slider, reduce_sample_checkbox], outputs=[markdown, html])
def setup_basemap_data(path='100k_filtered_OA_sample_cluster_and_positions.pkl'):
    """Load the precomputed basemap data from a pickle file.

    predict() expects the result to be a DataFrame. It reads the
    x/y/cluster_*_labels/title/doi columns after concatenating it with the
    query results. (A compress_pickle '.bz' version of this file was used
    previously.)

    Args:
        path: pickle file to load; defaults to the file shipped with the app.

    Returns:
        The unpickled object.

    NOTE: pickle can execute arbitrary code on load; only use trusted files.
    """
    print("getting basemap data...")
    with open(path, 'rb') as f:
        basedata_df = pickle.load(f)
    print(basedata_df)
    return basedata_df
def setup_mapper(path='umap_mapper_300k_random_OA_specter_2_params.pkl'):
    """Rebuild a fitted UMAP reducer from a pickled parameter/attribute dump.

    The pickle is expected to be a dict with optional entries:
    'umap_params' (passed to ``UMAP.set_params``, minus 'target_backend') and
    'umap_attributes' (set on the reducer via ``setattr``).

    Args:
        path: pickle file to load; defaults to the file shipped with the app.

    Returns:
        A ``umap.UMAP`` instance populated from the dump.

    NOTE: pickle can execute arbitrary code on load; only use trusted files.
    """
    print("getting mapper...")
    with open(path, 'rb') as f:
        params_new = pickle.load(f)
    print("setting up mapper...")
    mapper = umap.UMAP()
    # Filter out 'target_backend' from umap_params if it exists
    # (presumably not accepted by the installed umap version — confirm).
    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items() if k != 'target_backend'}
    mapper.set_params(**umap_params)
    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)
    if 'embedding_' in umap_attributes:
        # NOTE(review): embedding_ is wrapped in a numba typed List rather than
        # kept as-is; the reason is not visible here — confirm before changing.
        mapper.embedding_ = List(umap_attributes['embedding_'])
    return mapper
# Load the basemap points and the fitted UMAP reducer once at startup;
# predict() reads both as module globals.
basedata_df = setup_basemap_data()
mapper = setup_mapper()
# mount Gradio app to FastAPI app (the /static mount was registered earlier,
# so it is matched before this "/" mount)
app = gr.mount_gradio_app(app, block, path="/")
# serve the app
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
# run the app with
# python app.py
# or
# uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload