import os
#os.system("pip uninstall -y gradio")
#os.system("pip install --upgrade gradio")
#os.system("pip install datamapplot==0.3.0")
#os.system("pip install numba==0.59.1")
#os.system("pip install umap-learn==0.5.6")
#os.system("pip install pynndescent==0.5.12")

import spaces
from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
import sys

gr.set_static_paths(paths=["static/"])

# create a FastAPI app
app = FastAPI()

# create a static directory to store the static files
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# mount FastAPI StaticFiles server
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Gradio stuff
import datamapplot
import numpy as np
import requests
import io
import pandas as pd
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from itertools import chain
from compress_pickle import load, dump
from urllib.parse import urlparse, parse_qs
import re
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
from tqdm import tqdm
from numba.typed import List
import pickle
import pynndescent
import umap


def openalex_url_to_pyalex_query(url):
    """
    Convert an OpenAlex search URL to a pyalex query.

    Only the ``filter`` and ``sort`` query parameters are applied to the
    query; ``page``, ``per-page``, ``sample`` and ``seed`` are collected
    into the returned dict and left for the caller to use.

    Args:
        url (str): The OpenAlex search URL.

    Returns:
        tuple: (Works object, dict of parameters). The dict values are the
        raw strings taken from the URL.
    """
    query_params = parse_qs(urlparse(url).query)

    # Initialize the Works object
    query = Works()

    # Handle filters: comma-separated "key:value" pairs; entries without a
    # colon are ignored.
    if 'filter' in query_params:
        for f in query_params['filter'][0].split(','):
            if ':' not in f:
                continue
            key, value = f.split(':', 1)
            if key == 'default.search':
                query = query.search(value)
            else:
                query = query.filter(**{key: value})

    # Handle sort: accept "field:direction", the "-field" shorthand for
    # descending, and a bare "field" (ascending).
    if 'sort' in query_params:
        for s in query_params['sort'][0].split(','):
            if not s:
                # e.g. a trailing comma; there is nothing to sort by
                continue
            if ':' in s:
                field, direction = s.split(':', 1)
                query = query.sort(**{field: direction})
            elif s.startswith('-'):
                query = query.sort(**{s[1:]: 'desc'})
            else:
                query = query.sort(**{s: 'asc'})

    # Handle other parameters
    params = {
        key: query_params[key][0]
        for key in ['page', 'per-page', 'sample', 'seed']
        if key in query_params
    }

    return query, params


def invert_abstract(inv_index):
    """
    Rebuild a plain-text abstract from an inverted index.

    Args:
        inv_index (dict | None): Mapping of word -> list of positions, or
            None when no abstract is available.

    Returns:
        str: The words joined by single spaces in ascending position order,
        or a single space when ``inv_index`` is None.
    """
    if inv_index is None:
        return ' '
    word_positions = [(w, p) for w, pos in inv_index.items() for p in pos]
    word_positions.sort(key=lambda pair: pair[1])
    return " ".join(w for w, _ in word_positions)


def get_pub(x):
    """
    Extract the source display name from a location record.

    Args:
        x: A mapping with a nested ``['source']['display_name']`` entry;
            may be None or lack those keys.

    Returns:
        The display name, or a single space when it cannot be read or is
        one of the placeholder names 'parsed_publication' / 'Deleted Journal'.
    """
    try:
        source = x['source']['display_name']
    except (KeyError, TypeError):
        # Missing keys, or x / x['source'] is None (not subscriptable)
        return ' '
    if source in ['parsed_publication', 'Deleted Journal']:
        return ' '
    return source


#def query_records(search_term):
#    # Fetch records based on the search term in the abstract!
# query = Works().search([search_term]) # query_length = Works().search([search_term]).count() # records = [] # #total_pages = (query_length + 199) // 200 # Calculate total number of pages # progress=gr.Progress() # for i, record in progress.tqdm(enumerate(chain(*query.paginate(per_page=200)))): # records.append(record) # # Calculate progress from 0 to 0.1 # #achieved_progress = min(0.1, (i + 1) / query_length * 0.1) # # Update progress bar # #progress(achieved_progress, desc="Getting queried data...") # records_df = pd.DataFrame(records) # records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']] # records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']] # records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ') # records_df['abstract'] = records_df['abstract'].fillna(' ') # records_df['title'] = records_df['title'].fillna(' ') # return records_df ################# Setting up the model for specter2 embeddings ################### #device = torch.device("mps" if torch.backends.mps.is_available() else "cuda") #print(f"Using device: {device}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base') model = AutoAdapterModel.from_pretrained('allenai/specter2_aug2023refresh_base') @spaces.GPU(duration=60) def create_embeddings(texts_to_embedd): # Set up the device print(len(texts_to_embedd)) # Load the proximity adapter and activate it model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True) model.set_active_adapters("proximity") model.to(device) def batch_generator(data, batch_size): """Yield consecutive batches of data.""" for i in range(0, len(data), batch_size): yield data[i:i + batch_size] def encode_texts(texts, device, batch_size=16): """Process texts in batches and return their embeddings.""" 
model.eval() with torch.no_grad(): all_embeddings = [] count = 0 for batch in tqdm(batch_generator(texts, batch_size)): inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device) outputs = model(**inputs) embeddings = outputs.last_hidden_state[:, 0, :] # Taking the [CLS] token representation all_embeddings.append(embeddings.cpu()) # Move to CPU to free GPU memory #torch.mps.empty_cache() # Clear cache to free up memory if count == 100: #torch.mps.empty_cache() torch.cuda.empty_cache() count = 0 count +=1 all_embeddings = torch.cat(all_embeddings, dim=0) return all_embeddings # Concatenate title and abstract embeddings = encode_texts(texts_to_embedd, device, batch_size=32).cpu().numpy() # Process texts in batches of 10 return embeddings def predict(text_input, sample_size_slider, reduce_sample_checkbox, progress=gr.Progress()): print('getting data to project') progress(0, desc="Starting...") query, params = openalex_url_to_pyalex_query(text_input) query_length = query.count() print(f'Requesting {query_length} entries...') records = [] total_pages = (query_length + 199) // 200 # Calculate total number of pages per_page = 0.3 / total_pages for i, record in enumerate(chain(*query.paginate(per_page=200))): records.append(record) # Update progress bar progress(per_page * i, desc="Getting queried data...") records_df = pd.DataFrame(records) records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']] records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']] records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ') records_df['abstract'] = records_df['abstract'].fillna(' ') records_df['title'] = records_df['title'].fillna(' ') if reduce_sample_checkbox: records_df = records_df.sample(sample_size_slider) print(records_df) progress(0.3, desc="Embedding Data...") texts_to_embedd = [title + tokenizer.sep_token + publication + tokenizer.sep_token + 
abstract for title, publication, abstract in zip(records_df['title'],records_df['parsed_publication'], records_df['abstract'])] embeddings = create_embeddings(texts_to_embedd) print(embeddings) progress(0.5, desc="Project into UMAP-embedding...") umap_embeddings = mapper.transform(embeddings) records_df[['x','y']] = umap_embeddings basedata_df['color'] = '#ced4d211' records_df['color'] = '#f98e31' progress(0.6, desc="Set up data...") stacked_df = pd.concat([basedata_df,records_df], axis=0, ignore_index=True) stacked_df = stacked_df.fillna("Unlabelled") stacked_df = stacked_df.reset_index(drop=True) print(stacked_df) extra_data = pd.DataFrame(stacked_df['doi']) file_name = f"{datetime.utcnow().strftime('%s')}.html" file_path = static_dir / file_name print(file_path) # progress(0.7, desc="Plotting...") custom_css = """ #title-container { background: #edededaa; border-radius: 2px; box-shadow: 2px 3px 10px #aaaaaa00; } #search-container { position: fixed !important; top: 20px !important; right: 20px !important; left: auto !important; width: 200px !important; z-index: 9999 !important; } #search { // padding: 8px 8px !important; // border: none !important; // border-radius: 20px !important; background-color: #ffffffaa !important; font-family: 'Roboto Condensed', sans-serif !important; font-size: 14px; // box-shadow: 0 0px 0px #aaaaaa00 !important; } """ plot = datamapplot.create_interactive_plot( stacked_df[['x','y']].values, np.array(stacked_df['cluster_1_labels']),np.array(stacked_df['cluster_2_labels']),np.array(stacked_df['cluster_3_labels']), hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()], marker_color_array=stacked_df['color'], use_medoids=True, width=1000, height=1000, # title='The Science of Consciousness ', # sub_title=f'