# Spaces: Build error
import os
import subprocess
import sys

# Reinstall gradio at start-up (presumably to get a release that provides
# gr.set_static_paths, used below -- confirm). `sys.executable -m pip` targets
# the interpreter that is actually running this app; a bare `pip` found on
# PATH may belong to a different Python. Exit codes are not checked, so a
# failed pip run is not fatal here (same as with os.system).
# NOTE(review): pinning gradio in requirements.txt would avoid reinstalling
# on every start-up.
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "gradio"], check=False)

import spaces
from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime

# Let Gradio serve files from the static/ folder as well.
gr.set_static_paths(paths=["static/"])
# FastAPI app that hosts both the static file server and the Gradio UI.
app = FastAPI()
# Directory for generated HTML plots; created on start-up if it is missing.
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)
# Serve everything in static_dir under the /static URL prefix.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Data fetching (OpenAlex), embedding (SPECTER2) and plotting dependencies
import datamapplot | |
import numpy as np | |
import requests | |
import io | |
import pandas as pd | |
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders | |
from itertools import chain | |
from compress_pickle import load, dump | |
from transformers import AutoTokenizer | |
from adapters import AutoAdapterModel | |
import torch | |
from tqdm import tqdm | |
def invert_abstract(inv_index):
    """Rebuild abstract text from an OpenAlex inverted index.

    Args:
        inv_index: Mapping ``{word: [positions, ...]}``, or None/NaN when the
            work has no abstract.

    Returns:
        The words joined by single spaces in position order, or ``' '`` when
        no index is available.
    """
    # Anything that is not a dict (None, or NaN filled in by pandas) means "no abstract".
    if not isinstance(inv_index, dict):
        return ' '
    positioned_words = [(word, pos) for word, positions in inv_index.items() for pos in positions]
    # Sort on position only (stable), so ties keep their insertion order.
    return " ".join(word for word, _ in sorted(positioned_words, key=lambda item: item[1]))


def get_pub(location):
    """Return the venue name from an OpenAlex ``primary_location`` dict.

    Returns ``' '`` when the location, its source, or the source name is
    missing or not a string. It also returns ``' '`` when the name is one of
    the placeholder names filtered below. The result is therefore always a
    string and safe to concatenate.
    """
    try:
        source = location['source']['display_name']
    except (KeyError, TypeError):
        # TypeError covers location/source being None (or NaN) instead of a dict.
        return ' '
    if not isinstance(source, str) or source in ('parsed_publication', 'Deleted Journal'):
        return ' '
    return source


def query_records(search_term):
    """Fetch all OpenAlex works whose abstract matches `search_term`.

    Args:
        search_term: Text passed to OpenAlex's abstract search filter.

    Returns:
        A DataFrame with one row per work: the raw OpenAlex fields plus
        ``abstract`` (reconstructed text) and ``parsed_publication`` (venue
        name or ``' '``). When nothing matches, it is an empty DataFrame that
        still has ``title``, ``abstract`` and ``parsed_publication`` columns.
    """
    query = Works().search_filter(abstract=search_term)
    # paginate() yields pages (lists of work dicts); flatten them into one list.
    records = list(chain.from_iterable(query.paginate(per_page=200)))
    records_df = pd.DataFrame(records)
    if records_df.empty:
        # An empty frame has no OpenAlex columns; return the ones callers read.
        return pd.DataFrame(columns=['title', 'abstract', 'parsed_publication'])
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]
    return records_df
################# Setting up the model for specter2 embeddings ###################
def _select_device():
    """Pick the torch device for inference: Apple MPS, then CUDA, then CPU.

    The CPU fallback keeps model.to(device) from failing on hosts that have
    neither an MPS nor a CUDA device.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _select_device()
print(f"Using device: {device}")
# SPECTER2 base encoder; the task-specific "proximity" adapter is attached later
# in create_embeddings.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_aug2023refresh_base')
# True once the SPECTER2 "proximity" adapter has been loaded into `model`.
_proximity_adapter_loaded = False


def create_embeddings(texts_to_embedd):
    """Embed texts with SPECTER2 plus its "proximity" adapter.

    Args:
        texts_to_embedd: Sequence of strings to embed (predict() passes
            title [SEP] venue [SEP] abstract).

    Returns:
        A NumPy array with one row per input text, holding the first-token
        ([CLS]) vector of the model's last hidden state. For an empty input
        it is an empty array with 0 rows and model.config.hidden_size columns.
    """
    global _proximity_adapter_loaded
    print(len(texts_to_embedd))

    # Load the adapter only once: re-loading it on every request rebuilds the
    # adapter weights each time.
    if not _proximity_adapter_loaded:
        model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
        _proximity_adapter_loaded = True
    model.set_active_adapters("proximity")
    model.to(device)

    if len(texts_to_embedd) == 0:
        # torch.cat([]) would raise below; return an empty embedding matrix instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    def empty_device_cache(dev):
        """Release cached allocator memory on `dev`'s backend (no-op on CPU)."""
        # torch.mps.empty_cache() raises when the MPS backend is absent, so
        # dispatch on the device actually in use.
        if dev.type == "cuda":
            torch.cuda.empty_cache()
        elif dev.type == "mps":
            torch.mps.empty_cache()

    def batch_generator(data, batch_size):
        """Yield consecutive batches of data."""
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    def encode_texts(texts, device, batch_size=16):
        """Process texts in batches and return their embeddings as one CPU tensor."""
        model.eval()
        all_embeddings = []
        with torch.no_grad():
            for batch_index, batch in enumerate(tqdm(batch_generator(texts, batch_size))):
                inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
                outputs = model(**inputs)
                embeddings = outputs.last_hidden_state[:, 0, :]  # Taking the [CLS] token representation
                all_embeddings.append(embeddings.cpu())  # Move to CPU to free GPU memory
                # Periodically release cached device memory (every 100 batches).
                if batch_index and batch_index % 100 == 0:
                    empty_device_cache(device)
        return torch.cat(all_embeddings, dim=0)

    # Process texts in batches of 32.
    return encode_texts(texts_to_embedd, device, batch_size=32).cpu().numpy()
def predict(text_input, progress=gr.Progress()):
    """Search OpenAlex for `text_input`, embed the hits, and save a DataMapPlot page.

    Args:
        text_input: Search term, passed to query_records().
        progress: Gradio progress tracker (the default is Gradio's idiom for
            requesting one).

    Returns:
        A (link_html, iframe_html) tuple, both pointing at the generated HTML
        file under /static/. When the search finds nothing, the tuple is
        ("No records found.", "").
    """
    import time
    import uuid

    def as_text(value):
        # OpenAlex fields may be None (and pandas fills gaps with NaN); treat any
        # non-string as empty so the concatenation below cannot raise TypeError.
        return value if isinstance(value, str) else ""

    # get data.
    records_df = query_records(text_input)
    print(records_df)
    if records_df.empty:
        return "No records found.", ""

    sep = tokenizer.sep_token
    texts_to_embedd = [
        as_text(title) + sep + as_text(publication) + sep + as_text(abstract)
        for title, publication, abstract in zip(
            records_df['title'], records_df['parsed_publication'], records_df['abstract']
        )
    ]
    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)
    # NOTE(review): `embeddings` is not used by the plot below, which shows only
    # the precomputed basemap (basedata_df). Presumably the query results are
    # meant to be projected onto it -- confirm.

    # Epoch seconds plus a random suffix, so concurrent requests in the same
    # second get distinct files. (strftime('%s') is a non-portable glibc
    # extension and interprets the naive utcnow() value as local time.)
    file_name = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Loading hover data...")
    # Build hover strings column-wise; iterrows() creates a Series per row and is
    # slow on a basemap of this size.
    hover_text = [
        str(ix) + ', ' + str(publication) + str(title)
        for ix, publication, title in zip(
            basedata_df.index, basedata_df['parsed_publication'], basedata_df['title']
        )
    ]
    plot = datamapplot.create_interactive_plot(
        basedata_df[['x', 'y']].values,
        np.array(basedata_df['cluster_1_labels']),
        hover_text=hover_text,
        font_family="Roboto Condensed",
    )
    progress(0.9, desc="Saving plot...")
    plot.save(file_path)
    progress(1.0, desc="Done!")

    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
_INTRO_MARKDOWN = """
## Gradio + FastAPI + Static Server
This is a demo of how to use Gradio with FastAPI and a static server.
The Gradio app generates dynamic HTML files and stores them in a static directory. FastAPI serves the static files.
"""

# Gradio UI: query input and trigger on the left, HTML preview on the right.
with gr.Blocks() as block:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        # Left column: search text, a Markdown slot for the file link, the button.
        with gr.Column():
            text_input = gr.Textbox(label="Name")
            markdown = gr.Markdown(label="Output Box")
            new_btn = gr.Button("New")
        # Right column: inline preview of the generated page.
        with gr.Column():
            html = gr.HTML(label="HTML preview", show_label=True)
    # Clicking the button runs predict(text_input); its two return values fill
    # the Markdown link slot and the HTML preview, in that order.
    new_btn.click(fn=predict, inputs=[text_input], outputs=[markdown, html])
def setup_basemap_data():
    """Download the background-map sample and load it.

    The archive is streamed to ``static/100k_filtered_OA_sample_cluster_and_positions.bz``
    and then read back with ``compress_pickle.load``.

    Returns:
        The loaded object. predict() uses it as a DataFrame with ``x``, ``y``,
        ``cluster_1_labels``, ``parsed_publication`` and ``title`` columns.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond in time.
    """
    # get data.
    print("getting basemap data...")
    url = "https://www.maxnoichl.eu/full/oa_project_on_scimap_background_data/100k_filtered_OA_sample_cluster_and_positions.bz"
    static_dir = Path("static")
    static_dir.mkdir(exist_ok=True)
    bz_file_name = "100k_filtered_OA_sample_cluster_and_positions.bz"
    bz_file_path = static_dir / bz_file_name

    # Fail loudly on HTTP errors: otherwise an error page would be saved as the
    # .bz file and load() would fail later with a confusing decompression error.
    # Streaming writes the body to disk in chunks instead of holding it in memory.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(bz_file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    # SECURITY: compress_pickle.load unpickles the downloaded file, which can run
    # arbitrary code -- only acceptable because the source URL is trusted.
    # NOTE(review): static/ is publicly served under /static/, so this data file
    # is downloadable by anyone using the app -- confirm that is intended.
    basedata_df = load(bz_file_path)
    return basedata_df
# Fetch and load the background map once, at import time; predict() reads this
# module-level frame every time it builds a plot.
basedata_df = setup_basemap_data()

# Attach the Gradio UI at the root of the FastAPI app. The /static mount was
# registered earlier, so /static/... URLs are still handled by StaticFiles.
app = gr.mount_gradio_app(app, block, path="/")

if __name__ == "__main__":
    # Start directly with:     python app.py
    # or via the uvicorn CLI:  uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)