m7n committed on
Commit
a0eb5f6
·
verified ·
1 Parent(s): 8c642e5

Update app.py

Browse files

added SPECTER2 first attempt

Files changed (1) hide show
  1. app.py +89 -0
app.py CHANGED
@@ -44,6 +44,16 @@ from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers,
44
  from itertools import chain
45
  from compress_pickle import load, dump
46
 
 
 
 
 
 
 
 
 
 
 
47
  def query_records(search_term):
48
  def invert_abstract(inv_index):
49
  if inv_index is not None:
@@ -67,6 +77,67 @@ def query_records(search_term):
67
 
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def predict(text_input, progress=gr.Progress()):
71
 
72
  # get data.
@@ -75,6 +146,24 @@ def predict(text_input, progress=gr.Progress()):
75
 
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  file_name = f"{datetime.utcnow().strftime('%s')}.html"
79
  file_path = static_dir / file_name
80
  print(file_path)
 
44
  from itertools import chain
45
  from compress_pickle import load, dump
46
 
47
+
48
+
49
+ from transformers import AutoTokenizer
50
+ from adapters import AutoAdapterModel
51
+ import torch
52
+ from tqdm import tqdm
53
+
54
+
55
+
56
+
57
  def query_records(search_term):
58
  def invert_abstract(inv_index):
59
  if inv_index is not None:
 
77
 
78
 
79
 
80
################# Setting up the model for specter2 embeddings ###################

# Pick the best available accelerator: Apple-silicon MPS first, then CUDA,
# and finally the CPU so the module still works on machines without a GPU
# (previously "cuda" was selected unconditionally whenever MPS was missing).
# NOTE(review): on hosted runtimes that attach the GPU lazily, confirm that
# torch.cuda.is_available() already reports True at import time.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# The tokenizer and the base model are loaded once at import time and are
# used as module-level globals by the rest of the file.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_aug2023refresh_base')
89
+
90
+
91
+
92
def create_embeddings(texts_to_embedd):
    """Embed a list of texts with the SPECTER2 model and its proximity adapter.

    Args:
        texts_to_embedd: List of strings to embed. Each string is tokenized
            and truncated to at most 512 tokens.

    Returns:
        A NumPy array with one row per input text. Each row is the hidden
        state of the first ([CLS]) token of the model's last layer. For an
        empty input an array with zero rows is returned.
    """
    print(len(texts_to_embedd))

    # torch.cat() raises on an empty list, so short-circuit before doing any
    # model work. Assumes the model config exposes `hidden_size` (true for
    # BERT-style configs) -- TODO confirm if the base model is ever swapped.
    if not texts_to_embedd:
        return torch.empty((0, model.config.hidden_size)).numpy()

    # Load the proximity adapter and activate it.
    # NOTE(review): this runs on every call; consider loading the adapter
    # once next to the model setup if repeated calls become a bottleneck.
    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")

    model.to(device)

    def batch_generator(data, batch_size):
        """Yield consecutive batches of data."""
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]

    @spaces.GPU(duration=120)
    def encode_texts(texts, device, batch_size=16):
        """Process texts in batches and return their embeddings as a CPU tensor."""
        device_type = torch.device(device).type
        # Ceiling division so tqdm can display a proper progress percentage.
        total_batches = (len(texts) + batch_size - 1) // batch_size

        model.eval()
        all_embeddings = []
        with torch.no_grad():
            batches = tqdm(batch_generator(texts, batch_size), total=total_batches)
            for batch_number, batch in enumerate(batches, start=1):
                inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
                outputs = model(**inputs)
                # Taking the [CLS] token representation.
                embeddings = outputs.last_hidden_state[:, 0, :]

                # Move to CPU right away to free accelerator memory.
                all_embeddings.append(embeddings.cpu())

                # Periodically release cached accelerator memory. The call
                # must match the backend in use: torch.mps.empty_cache() is
                # only valid when running on MPS.
                if batch_number % 100 == 0:
                    if device_type == "mps":
                        torch.mps.empty_cache()
                    elif device_type == "cuda":
                        torch.cuda.empty_cache()

        return torch.cat(all_embeddings, dim=0)

    # Encode in batches of 32; the result is already on the CPU.
    embeddings = encode_texts(texts_to_embedd, device, batch_size=32).numpy()

    return embeddings
136
+
137
+
138
+
139
+
140
+
141
  def predict(text_input, progress=gr.Progress()):
142
 
143
  # get data.
 
146
 
147
 
148
 
149
+ texts_to_embedd = [title + tokenizer.sep_token + publication + tokenizer.sep_token + abstract for title, publication, abstract in zip(records_df['title'],records_df['parsed_publication'], records_df['abstract'])]
150
+
151
+ embeddings = create_embeddings(texts_to_embedd)
152
+ print(embeddings)
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
  file_name = f"{datetime.utcnow().strftime('%s')}.html"
168
  file_path = static_dir / file_name
169
  print(file_path)