m7n committed on
Commit 4ea4863 · verified · 1 Parent(s): feb835a

Update app.py

Files changed (1)
  1. app.py +11 -7
app.py CHANGED
@@ -72,7 +72,7 @@ import umap
 
 
 
-def query_records(search_term,progress):
+def query_records(search_term):
     def invert_abstract(inv_index):
         if inv_index is not None:
             l_inv = [(w, p) for w, pos in inv_index.items() for p in pos]
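Note: the hunk cuts off inside invert_abstract. For context, a minimal sketch of the full helper, assuming the usual recipe for OpenAlex's abstract_inverted_index (sort word/position pairs by position, then join); the fallback return value is an assumption, not part of this commit:

def invert_abstract(inv_index):
    # inv_index maps each word to the list of positions it occupies
    if inv_index is not None:
        l_inv = [(w, p) for w, pos in inv_index.items() for p in pos]
        # sort by position and stitch the words back into plain text
        return " ".join(w for w, p in sorted(l_inv, key=lambda x: x[1]))
    else:
        return " "  # assumed placeholder for records without an abstract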
@@ -95,16 +95,17 @@ def query_records(search_term,progress):
     query_length = Works().search([search_term]).count()
 
     records = []
-    total_pages = (query_length + 199) // 200 # Calculate total number of pages
+    #total_pages = (query_length + 199) // 200 # Calculate total number of pages
+    progress=gr.Progress()
 
-    for i, record in enumerate(chain(*query.paginate(per_page=200))):
+    for i, record in progress.tqdm(enumerate(chain(*query.paginate(per_page=200)))):
        records.append(record)
 
         # Calculate progress from 0 to 0.1
-        achieved_progress = min(0.1, (i + 1) / query_length * 0.1)
+        #achieved_progress = min(0.1, (i + 1) / query_length * 0.1)
 
         # Update progress bar
-        progress(achieved_progress, desc="Getting queried data...")
+        #progress(achieved_progress, desc="Getting queried data...")
 
 
 
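Note: this hunk replaces the hand-rolled progress(fraction, desc=...) callback with Gradio's tqdm-style tracker. A minimal self-contained sketch of that pattern, assuming Gradio 3.x or later; fetch_records is an illustrative name, not the app's function. Gradio's documented pattern receives the tracker as a default argument (whether an inline gr.Progress() constructed inside the body is tracked may depend on the Gradio version), and wrapping a bare enumerate(...) as the commit does gives the bar no total, so it stays indeterminate:

import gradio as gr

def fetch_records(n, progress=gr.Progress()):
    records = []
    # Gradio advances the bar each time the wrapped iterable yields;
    # range(...) has a known length, so the bar shows real percentages
    for i in progress.tqdm(range(int(n)), desc="Getting queried data..."):
        records.append(i)
    return len(records)

demo = gr.Interface(fn=fetch_records, inputs=gr.Number(value=200), outputs="number")
demo.launch()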
@@ -128,7 +129,9 @@ def query_records(search_term,progress):
 
 
 
-device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
+#device = torch.device("mps" if torch.backends.mps.is_available() else "cuda")
+#print(f"Using device: {device}")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
 tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_aug2023refresh_base')
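Note: the old line picked "cuda" whenever MPS was unavailable and would fail at first use on a CPU-only host; the new line is the standard CUDA-or-CPU fallback but drops MPS support. A sketch of a three-way fallback that would cover both deployment targets (pick_device is a hypothetical helper, not part of this commit):

import torch

def pick_device():
    # Prefer CUDA, then Apple-silicon MPS, then plain CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")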
@@ -168,7 +171,8 @@ def create_embeddings(texts_to_embedd):
         all_embeddings.append(embeddings.cpu()) # Move to CPU to free GPU memory
         #torch.mps.empty_cache() # Clear cache to free up memory
         if count == 100:
-            torch.mps.empty_cache()
+            #torch.mps.empty_cache()
+            torch.cuda.empty_cache()
             count = 0
 
         count +=1
 
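Note: this hunk retargets the every-100-batches cache flush from MPS to CUDA. A self-contained sketch of the surrounding pattern, assuming batches of tokenized inputs and a SPECTER2-style Hugging Face encoder; embed_in_batches, model, and batches are illustrative names, not the app's actual code:

import torch

def embed_in_batches(model, batches):
    all_embeddings = []
    count = 0
    for batch in batches:
        with torch.no_grad():
            # first-token (CLS) embedding, the usual pooling for SPECTER2
            embeddings = model(**batch).last_hidden_state[:, 0, :]
        all_embeddings.append(embeddings.cpu())  # move off the GPU right away
        if count == 100:
            # return cached allocator blocks to the driver every 100 batches;
            # this does not free tensors that are still referenced
            torch.cuda.empty_cache()
            count = 0
        count += 1
    return torch.cat(all_embeddings)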