Ramlaoui committed on
Commit
a10ccb7
·
1 Parent(s): 3daf8bc

Use sparse matrices

Browse files
Files changed (2) hide show
  1. app.py +5 -1
  2. data_utils.py +24 -21
app.py CHANGED
@@ -18,7 +18,6 @@ from components import (
18
  get_upload_div,
19
  )
20
  from data_utils import (
21
- build_embeddings_index,
22
  build_formula_index,
23
  get_crystal_plot,
24
  get_dataset,
@@ -29,6 +28,11 @@ from data_utils import (
29
  EMPTY_DATA = False
30
  CACHE_PATH = None
31
 
 
 
 
 
 
32
  dataset = get_dataset()
33
 
34
  display_columns_query = [
 
18
  get_upload_div,
19
  )
20
  from data_utils import (
 
21
  build_formula_index,
22
  get_crystal_plot,
23
  get_dataset,
 
28
  EMPTY_DATA = False
29
  CACHE_PATH = None
30
 
31
+ if CACHE_PATH is not None:
32
+ import os
33
+
34
+ os.makedirs(CACHE_PATH, exist_ok=True)
35
+
36
  dataset = get_dataset()
37
 
38
  display_columns_query = [
data_utils.py CHANGED
@@ -72,6 +72,7 @@ mapping_table_idx_dataset_idx = {}
72
 
73
 
74
  def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
 
75
  if empty_data:
76
  return np.zeros((1, 1)), {}
77
 
@@ -80,40 +81,42 @@ def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=F
80
  use_dataset = dataset.select(index_range)
81
 
82
  # Preprocessing step to create an index for the dataset
83
- if cache_path is not None:
84
- train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
85
 
86
- dataset_index = pickle.load(open(f"{cache_path}/dataset_index.pkl", "rb"))
 
 
87
  else:
88
  train_df = use_dataset.select_columns(
89
- ["chemical_formula_descriptive", "immutable_id"]
90
  ).to_pandas()
91
 
92
- pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
93
- extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
94
- extracted["count"] = extracted["count"].replace("", "1").astype(int)
95
-
96
- wide_df = (
97
- extracted.reset_index().pivot_table( # Move index to columns for pivoting
98
- index="level_0", # original row index
99
- columns="element",
100
- values="count",
101
- aggfunc="sum",
102
- fill_value=0,
103
- )
104
- )
105
 
106
- all_elements = [el.symbol for el in periodictable.elements] # full element list
107
- wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
 
 
108
 
109
- dataset_index = wide_df.values
 
 
110
 
111
  dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
112
  dataset_index = (
113
  dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
114
  ) # Normalize vectors
115
 
 
 
 
 
 
 
 
 
116
  immutable_id_to_idx = train_df["immutable_id"].to_dict()
 
117
  immutable_id_to_idx = {v: k for k, v in immutable_id_to_idx.items()}
118
 
119
  return dataset_index, immutable_id_to_idx
@@ -162,7 +165,7 @@ def search_materials(
162
  numb = int(numb) if numb else 1
163
  query_vector[map_periodic_table[el]] = numb
164
 
165
- similarity = np.dot(dataset_index, query_vector) / (np.linalg.norm(query_vector))
166
  indices = np.argsort(similarity)[::-1][:top_k]
167
 
168
  options = [dataset[int(i)] for i in indices]
 
72
 
73
 
74
  def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
75
+ print("Building formula index")
76
  if empty_data:
77
  return np.zeros((1, 1)), {}
78
 
 
81
  use_dataset = dataset.select(index_range)
82
 
83
  # Preprocessing step to create an index for the dataset
84
+ from scipy.sparse import load_npz
 
85
 
86
+ if cache_path is not None and os.path.exists(f"{cache_path}/train_df.pkl"):
87
+ train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
88
+ dataset_index = load_npz(f"{cache_path}/dataset_index.npz")
89
  else:
90
  train_df = use_dataset.select_columns(
91
+ ["species_at_sites", "immutable_id", "functional"]
92
  ).to_pandas()
93
 
94
+ import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ all_elements = {
97
+ str(el.symbol): i for i, el in enumerate(periodictable.elements)
98
+ } # full element list
99
+ dataset_index = np.zeros((len(train_df), len(all_elements)))
100
 
101
+ for idx, species in tqdm.tqdm(enumerate(train_df["species_at_sites"].values)):
102
+ for el in species:
103
+ dataset_index[idx, all_elements[el]] += 1
104
 
105
  dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
106
  dataset_index = (
107
  dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
108
  ) # Normalize vectors
109
 
110
+ from scipy.sparse import csr_matrix, save_npz
111
+
112
+ dataset_index = csr_matrix(dataset_index)
113
+
114
+ if cache_path is not None:
115
+ pickle.dump(train_df, open(f"{cache_path}/train_df.pkl", "wb"))
116
+ save_npz(f"{cache_path}/dataset_index.npz", dataset_index)
117
+
118
  immutable_id_to_idx = train_df["immutable_id"].to_dict()
119
+ del train_df
120
  immutable_id_to_idx = {v: k for k, v in immutable_id_to_idx.items()}
121
 
122
  return dataset_index, immutable_id_to_idx
 
165
  numb = int(numb) if numb else 1
166
  query_vector[map_periodic_table[el]] = numb
167
 
168
+ similarity = dataset_index.dot(query_vector) / (np.linalg.norm(query_vector))
169
  indices = np.argsort(similarity)[::-1][:top_k]
170
 
171
  options = [dataset[int(i)] for i in indices]