Ramlaoui commited on
Commit
739c614
·
1 Parent(s): 1517af1

+Faster search -Higher memory

Browse files
Files changed (1) hide show
  1. app.py +47 -47
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import gradio as gr
4
  import numpy as np
5
  import pandas as pd
 
6
  import plotly.graph_objs as go
7
  import polars as pl
8
  from datasets import concatenate_datasets, load_dataset
@@ -51,27 +52,29 @@ for subset in subsets:
51
  "functional",
52
  ],
53
  )
54
- subsets_ds[subset] = dataset["train"]
55
 
56
- # Convert the train split to a pandas DataFrame
57
- # df = pd.concat([x.to_pandas() for x in datasets])
58
- # train_df = dataset.to_pandas()
59
- # del dataset
60
 
61
- # dataset_element_combination_dict = {}
62
 
63
- # isubset = lambda x: set(x).issubset(element_list)
64
- # isintersection = lambda x: len(set(x).intersection(element_list)) > 0
65
- # for element_1 in Element:
66
- # for element_2 in Element:
67
- # for element_3 in Element:
68
- # if element_1 != element_2 and element_2 != element_3 and element_3 != element_1:
69
- # print("processing {},{},{}".format(*element_list))
70
- # element_list = [element_1.name, element_2.name, element_3.name]
71
- # dataset_element_combination_dict(sorted(tuple(element_list))) = dataset.filter(
72
- # lambda example: isintersection(example["elements"])
73
- # and isubset(example["elements"])
74
- # )
 
 
 
 
 
 
75
 
76
 
77
  def create_phase_diagram(
@@ -85,36 +88,31 @@ def create_phase_diagram(
85
  # Split elements and remove any whitespace
86
  element_list = [el.strip() for el in elements.split("-")]
87
 
88
- # Filter entries based on functional
89
- if functional == "PBE":
90
- entries_df = subsets_ds["compatible_pbe"].to_pandas()
91
- # entries_df = train_df[train_df["functional"] == "pbe"]
92
- elif functional == "PBESol":
93
- entries_df = subsets_ds["compatible_pbesol"].to_pandas()
94
- # entries_df = train_df[train_df["functional"] == "pbesol"]
95
- elif functional == "SCAN":
96
- entries_df = subsets_ds["compatible_scan"].to_pandas()
97
- # entries_df = train_df[train_df["functional"] == "scan"]
98
 
99
- # entries_df = df.to_pandas()
 
 
100
 
101
- entries_df = entries_df[~entries_df["immutable_id"].isna()]
102
-
103
- isubset = lambda x: set(x).issubset(element_list)
104
- isintersection = lambda x: len(set(x).intersection(element_list)) > 0
105
- entries_df = entries_df[
106
- [isintersection(l) and isubset(l) for l in entries_df.elements.values.tolist()]
107
  ]
108
 
109
- # df = df.filter((df.col("elements").list.contains(x) for x in element_list))
110
- # df = df.filter(
111
- # pl.col("elements")
112
- # .list.eval(pl.element().is_in(element_list))
113
- # .list.any()
114
- # .alias("check")
115
- # )
 
 
 
 
 
116
 
117
- # entries_df = df.to_pandas()
118
 
119
  # Fetch all entries from the Materials Project database
120
  def get_energy_correction(energy_correction, row):
@@ -153,6 +151,7 @@ def create_phase_diagram(
153
  try:
154
  phase_diagram = PhaseDiagram(entries)
155
  except ValueError as e:
 
156
  return go.Figure().add_annotation(text=str(e))
157
 
158
  # Generate plotly figure
@@ -188,7 +187,10 @@ elements_input = gr.Textbox(
188
  # minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
189
  # )
190
  energy_correction_dropdown = gr.Dropdown(
191
- choices=["The 110 PBE Method", "Database specific, or MP2020",],
 
 
 
192
  label="Energy correction",
193
  )
194
  plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
@@ -210,9 +212,7 @@ warning_message += " from <a href='https://chemrxiv.org/engage/api-gateway/chemr
210
  message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">&times;</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
211
  warning_message
212
  )
213
- message += (
214
- "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
215
- )
216
 
217
  # Create Gradio interface
218
  iface = gr.Interface(
 
3
  import gradio as gr
4
  import numpy as np
5
  import pandas as pd
6
+ import periodictable
7
  import plotly.graph_objs as go
8
  import polars as pl
9
  from datasets import concatenate_datasets, load_dataset
 
52
  "functional",
53
  ],
54
  )
55
+ subsets_ds[subset] = dataset["train"].to_pandas()
56
 
57
+ elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}
 
 
 
58
 
 
59
 
60
+ all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}
61
+ elements_indices = {}
62
+ for subset, df in elements_df.items():
63
+ print("Processing subset: ", subset)
64
+ elements_indices[subset] = np.zeros((len(df), len(all_elements)))
65
+
66
+ def map_elements(row):
67
+ index, xs = row["index"], row["elements"]
68
+ for x in xs:
69
+ elements_indices[subset][index, all_elements[x]] = 1
70
+
71
+ df = df.reset_index().apply(map_elements, axis=1)
72
+
73
+ map_functional = {
74
+ "PBE": "compatible_pbe",
75
+ "PBESol": "compatible_pbesol",
76
+ "SCAN": "compatible_scan",
77
+ }
78
 
79
 
80
  def create_phase_diagram(
 
88
  # Split elements and remove any whitespace
89
  element_list = [el.strip() for el in elements.split("-")]
90
 
91
+ subset_name = map_functional[functional]
 
 
 
 
 
 
 
 
 
92
 
93
+ element_list_vector = np.zeros(len(all_elements))
94
+ for el in element_list:
95
+ element_list_vector[all_elements[el]] = 1
96
 
97
+ n_elements = elements_indices[subset_name].sum(axis=1)
98
+ n_elements_query = elements_indices[subset_name][
99
+ :, element_list_vector.astype(bool)
 
 
 
100
  ]
101
 
102
+ if n_elements_query.shape[1] == 0:
103
+ indices_with_only_elements = []
104
+ else:
105
+ indices_with_only_elements = np.where(
106
+ n_elements_query.sum(axis=1) == n_elements
107
+ )[0]
108
+
109
+ print(indices_with_only_elements)
110
+
111
+ entries_df = subsets_ds[subset_name].loc[indices_with_only_elements]
112
+
113
+ entries_df = entries_df[~entries_df["immutable_id"].isna()]
114
 
115
+ print(entries_df)
116
 
117
  # Fetch all entries from the Materials Project database
118
  def get_energy_correction(energy_correction, row):
 
151
  try:
152
  phase_diagram = PhaseDiagram(entries)
153
  except ValueError as e:
154
+ print(e)
155
  return go.Figure().add_annotation(text=str(e))
156
 
157
  # Generate plotly figure
 
187
  # minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
188
  # )
189
  energy_correction_dropdown = gr.Dropdown(
190
+ choices=[
191
+ "The 110 PBE Method",
192
+ "Database specific, or MP2020",
193
+ ],
194
  label="Energy correction",
195
  )
196
  plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
 
212
  message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">&times;</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
213
  warning_message
214
  )
215
+ message += "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
 
 
216
 
217
  # Create Gradio interface
218
  iface = gr.Interface(