Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
+Faster search -Higher memory
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import os
|
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
|
|
6 |
import plotly.graph_objs as go
|
7 |
import polars as pl
|
8 |
from datasets import concatenate_datasets, load_dataset
|
@@ -51,27 +52,29 @@ for subset in subsets:
|
|
51 |
"functional",
|
52 |
],
|
53 |
)
|
54 |
-
subsets_ds[subset] = dataset["train"]
|
55 |
|
56 |
-
|
57 |
-
# df = pd.concat([x.to_pandas() for x in datasets])
|
58 |
-
# train_df = dataset.to_pandas()
|
59 |
-
# del dataset
|
60 |
|
61 |
-
# dataset_element_combination_dict = {}
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
|
77 |
def create_phase_diagram(
|
@@ -85,36 +88,31 @@ def create_phase_diagram(
|
|
85 |
# Split elements and remove any whitespace
|
86 |
element_list = [el.strip() for el in elements.split("-")]
|
87 |
|
88 |
-
|
89 |
-
if functional == "PBE":
|
90 |
-
entries_df = subsets_ds["compatible_pbe"].to_pandas()
|
91 |
-
# entries_df = train_df[train_df["functional"] == "pbe"]
|
92 |
-
elif functional == "PBESol":
|
93 |
-
entries_df = subsets_ds["compatible_pbesol"].to_pandas()
|
94 |
-
# entries_df = train_df[train_df["functional"] == "pbesol"]
|
95 |
-
elif functional == "SCAN":
|
96 |
-
entries_df = subsets_ds["compatible_scan"].to_pandas()
|
97 |
-
# entries_df = train_df[train_df["functional"] == "scan"]
|
98 |
|
99 |
-
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
isintersection = lambda x: len(set(x).intersection(element_list)) > 0
|
105 |
-
entries_df = entries_df[
|
106 |
-
[isintersection(l) and isubset(l) for l in entries_df.elements.values.tolist()]
|
107 |
]
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
|
118 |
|
119 |
# Fetch all entries from the Materials Project database
|
120 |
def get_energy_correction(energy_correction, row):
|
@@ -153,6 +151,7 @@ def create_phase_diagram(
|
|
153 |
try:
|
154 |
phase_diagram = PhaseDiagram(entries)
|
155 |
except ValueError as e:
|
|
|
156 |
return go.Figure().add_annotation(text=str(e))
|
157 |
|
158 |
# Generate plotly figure
|
@@ -188,7 +187,10 @@ elements_input = gr.Textbox(
|
|
188 |
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
|
189 |
# )
|
190 |
energy_correction_dropdown = gr.Dropdown(
|
191 |
-
choices=[
|
|
|
|
|
|
|
192 |
label="Energy correction",
|
193 |
)
|
194 |
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
|
@@ -210,9 +212,7 @@ warning_message += " from <a href='https://chemrxiv.org/engage/api-gateway/chemr
|
|
210 |
message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">×</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
|
211 |
warning_message
|
212 |
)
|
213 |
-
message +=
|
214 |
-
"<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
|
215 |
-
)
|
216 |
|
217 |
# Create Gradio interface
|
218 |
iface = gr.Interface(
|
|
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
+
import periodictable
|
7 |
import plotly.graph_objs as go
|
8 |
import polars as pl
|
9 |
from datasets import concatenate_datasets, load_dataset
|
|
|
52 |
"functional",
|
53 |
],
|
54 |
)
|
55 |
+
subsets_ds[subset] = dataset["train"].to_pandas()
|
56 |
|
57 |
+
elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}
|
|
|
|
|
|
|
58 |
|
|
|
59 |
|
60 |
+
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}
|
61 |
+
elements_indices = {}
|
62 |
+
for subset, df in elements_df.items():
|
63 |
+
print("Processing subset: ", subset)
|
64 |
+
elements_indices[subset] = np.zeros((len(df), len(all_elements)))
|
65 |
+
|
66 |
+
def map_elements(row):
|
67 |
+
index, xs = row["index"], row["elements"]
|
68 |
+
for x in xs:
|
69 |
+
elements_indices[subset][index, all_elements[x]] = 1
|
70 |
+
|
71 |
+
df = df.reset_index().apply(map_elements, axis=1)
|
72 |
+
|
73 |
+
map_functional = {
|
74 |
+
"PBE": "compatible_pbe",
|
75 |
+
"PBESol": "compatible_pbesol",
|
76 |
+
"SCAN": "compatible_scan",
|
77 |
+
}
|
78 |
|
79 |
|
80 |
def create_phase_diagram(
|
|
|
88 |
# Split elements and remove any whitespace
|
89 |
element_list = [el.strip() for el in elements.split("-")]
|
90 |
|
91 |
+
subset_name = map_functional[functional]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
+
element_list_vector = np.zeros(len(all_elements))
|
94 |
+
for el in element_list:
|
95 |
+
element_list_vector[all_elements[el]] = 1
|
96 |
|
97 |
+
n_elements = elements_indices[subset_name].sum(axis=1)
|
98 |
+
n_elements_query = elements_indices[subset_name][
|
99 |
+
:, element_list_vector.astype(bool)
|
|
|
|
|
|
|
100 |
]
|
101 |
|
102 |
+
if n_elements_query.shape[1] == 0:
|
103 |
+
indices_with_only_elements = []
|
104 |
+
else:
|
105 |
+
indices_with_only_elements = np.where(
|
106 |
+
n_elements_query.sum(axis=1) == n_elements
|
107 |
+
)[0]
|
108 |
+
|
109 |
+
print(indices_with_only_elements)
|
110 |
+
|
111 |
+
entries_df = subsets_ds[subset_name].loc[indices_with_only_elements]
|
112 |
+
|
113 |
+
entries_df = entries_df[~entries_df["immutable_id"].isna()]
|
114 |
|
115 |
+
print(entries_df)
|
116 |
|
117 |
# Fetch all entries from the Materials Project database
|
118 |
def get_energy_correction(energy_correction, row):
|
|
|
151 |
try:
|
152 |
phase_diagram = PhaseDiagram(entries)
|
153 |
except ValueError as e:
|
154 |
+
print(e)
|
155 |
return go.Figure().add_annotation(text=str(e))
|
156 |
|
157 |
# Generate plotly figure
|
|
|
187 |
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
|
188 |
# )
|
189 |
energy_correction_dropdown = gr.Dropdown(
|
190 |
+
choices=[
|
191 |
+
"The 110 PBE Method",
|
192 |
+
"Database specific, or MP2020",
|
193 |
+
],
|
194 |
label="Energy correction",
|
195 |
)
|
196 |
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
|
|
|
212 |
message = '<div class="alert"><span class="closebtn" onclick="this.parentElement.style.display="none";">×</span>{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data.'.format(
|
213 |
warning_message
|
214 |
)
|
215 |
+
message += "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a> and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
|
|
|
|
|
216 |
|
217 |
# Create Gradio interface
|
218 |
iface = gr.Interface(
|