Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import periodictable | |
import plotly.graph_objs as go | |
import polars as pl | |
from datasets import concatenate_datasets, load_dataset | |
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram | |
from pymatgen.core import Composition, Element, Structure | |
from pymatgen.core.composition import Composition | |
from pymatgen.entries.computed_entries import ( | |
ComputedStructureEntry, | |
GibbsComputedStructureEntry, | |
) | |
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Dataset configurations that are loaded at start-up.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]
# polars_dfs = { | |
# subset: pl.read_parquet( | |
# "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset), | |
# storage_options={ | |
# "token": HF_TOKEN, | |
# }, | |
# ) | |
# for subset in subsets | |
# } | |
# # Load only the train split of the dataset | |
# Columns requested from the dataset; only these are read by the app.
wanted_columns = (
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
)

# Maps each subset name to the pandas DataFrame of its "train" split.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        "LeMaterial/leMat1",
        subset,
        token=HF_TOKEN,
        columns=list(wanted_columns),
    )
    train_frame = dataset["train"].to_pandas()
    subsets_ds[subset] = train_frame
# Per-subset column holding, for every entry, the collection of its element symbols.
elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}

# Column position assigned to every symbol enumerated by `periodictable.elements`.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}

# Per-subset (n_entries, n_symbols) flag matrix: cell [row, col] is set when
# the entry at positional row `row` contains the element mapped to `col`.
elements_indices = {}
for subset, df in elements_df.items():
    print("Processing subset: ", subset)
    # Boolean flags are enough for membership and take 1/8 of the memory of
    # the default float64 matrix; sums over them still yield integer counts.
    membership = np.zeros((len(df), len(all_elements)), dtype=bool)
    # Iterate positionally instead of going through DataFrame.apply(axis=1):
    # that builds a Series per row and relied on index labels being positions.
    for position, symbols in enumerate(df):
        for symbol in symbols:
            # A symbol missing from `all_elements` raises KeyError on purpose:
            # it signals data that the query code could never match.
            membership[position, all_elements[symbol]] = True
    elements_indices[subset] = membership
# Functional label (as offered in the UI dropdown) -> dataset subset name.
map_functional = dict(
    PBE="compatible_pbe",
    PBESol="compatible_pbesol",
    SCAN="compatible_scan",
)
def _message_figure(text):
    """Return an empty Plotly figure that only carries `text` as an annotation."""
    return go.Figure().add_annotation(text=text)


def _energy_correction(scheme, row):
    """Return the additive energy correction for one dataset row.

    The returned value is passed as `correction` to the computed entry, i.e.
    it is added on top of `row["energy"]`.
    """
    if scheme == "Database specific, or MP2020":
        corrected = row["energy_corrected"]
        # Rows without a corrected energy fall back to the raw energy.
        if pd.isna(corrected):
            return 0.0
        return corrected - row["energy"]
    if scheme == "The 110 PBE Method":
        # Corrected energy is 1.1 * E, so the additive part is 0.1 * E
        # (returning 1.1 * E here would yield a total of 2.1 * E).
        return 0.1 * row["energy"]
    # No (or an unknown) scheme selected: leave the energy uncorrected.
    return 0.0


def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a Plotly phase diagram for a chemical system.

    Args:
        elements: Element symbols separated by "-", e.g. "Li-Fe-O".
            Surrounding whitespace, empty parts and duplicates are ignored.
        energy_correction: Name of the correction scheme; one of
            "Database specific, or MP2020" or "The 110 PBE Method".
            Any other value applies no correction.
        plot_style: "2D" for the default plot; any other value requests the
            3D ternary plot.
        functional: Key of `map_functional` selecting the dataset subset.
        finite_temp: When truthy, entries are converted with
            `GibbsComputedStructureEntry.from_entries` before the hull is built.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A Plotly figure. Invalid input, a `ValueError` raised by
        `PhaseDiagram`, or a 3D request for a non-ternary system produce an
        empty figure annotated with the explanation instead of raising.
    """
    # Split elements, strip whitespace, drop empty parts and duplicates
    # while keeping the order in which the user typed them.
    symbols = (part.strip() for part in (elements or "").split("-"))
    element_list = list(dict.fromkeys(symbol for symbol in symbols if symbol))
    if not element_list:
        return _message_figure("Please enter at least one element.")

    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return _message_figure(
            "Unknown element symbol(s): {}".format(", ".join(unknown))
        )

    subset_name = map_functional.get(functional)
    if subset_name is None:
        return _message_figure("Please select a supported functional.")

    # Boolean mask over the symbol columns that belong to the query.
    query_mask = np.zeros(len(all_elements), dtype=bool)
    query_mask[[all_elements[el] for el in element_list]] = True

    # An entry is kept when all of its elements are within the query, i.e.
    # its element count restricted to the query equals its total count.
    membership = elements_indices[subset_name]
    n_elements = membership.sum(axis=1)
    n_elements_in_query = membership[:, query_mask].sum(axis=1)
    matching_rows = np.flatnonzero(n_elements_in_query == n_elements)

    # The membership matrix is indexed by row position, hence `iloc`.
    entries_df = subsets_ds[subset_name].iloc[matching_rows]
    entries_df = entries_df[entries_df["immutable_id"].notna()]

    entries = [
        ComputedStructureEntry(
            Structure(
                [vector.tolist() for vector in row["lattice_vectors"].tolist()],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            correction=_energy_correction(energy_correction, row),
            entry_id=row["immutable_id"],
            parameters={"run_type": row["functional"]},
        )
        for _, row in entries_df.iterrows()
    ]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        print(e)
        return _message_figure(str(e))

    # Generate plotly figure
    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
    elif len(element_list) == 3:
        # 3D plots are limited to ternary systems
        plotter = PDPlotter(
            phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
        )
    else:
        return _message_figure("3D plots are only available for ternary systems.")

    # Filtering by maximum energy above hull is not implemented
    # (PDPlotter does not support direct filtering).
    return plotter.get_plot()
# Input widgets of the Gradio interface.
elements_input = gr.Textbox(
    value="Li-Fe-O",
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
)

# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )

# Selectable energy correction schemes; create_phase_diagram compares
# against these exact strings.
energy_correction_choices = [
    "The 110 PBE Method",
    "Database specific, or MP2020",
]
energy_correction_dropdown = gr.Dropdown(
    label="Energy correction",
    choices=energy_correction_choices,
)

plot_style_dropdown = gr.Dropdown(
    label="Plot Style",
    choices=["2D", "3D"],
)

# The labels offered here are the keys of map_functional.
functional_dropdown = gr.Dropdown(
    label="Functional",
    choices=["PBE", "PBESol", "SCAN"],
)

finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Caveat about the energy corrections, rendered inside the dismissible
# alert box at the top of the interface description (HTML).
warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br> Additionally, we have provided the 110 PBE correction method"
    " from <a href='https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/67252d617be152b1d0b2c1ef/original/a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf' target='_blank'>Rohr et al (2024)</a>.<br>"
)

# Full interface description: alert box, one-line summary and credits.
# The onclick handler quotes 'none' with single quotes; nesting double
# quotes inside the double-quoted attribute would end the attribute early.
message = (
    '<div class="alert">'
    '<span class="closebtn"'
    " onclick=\"this.parentElement.style.display='none';\">×</span>"
    f"{warning_message}</div>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
    "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a>"
    " and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
)
# Widgets are listed in the positional order of create_phase_diagram's
# parameters: elements, energy_correction, plot_style, functional, finite_temp.
interface_inputs = [
    elements_input,
    # max_e_above_hull_slider,
    energy_correction_dropdown,
    plot_style_dropdown,
    functional_dropdown,
    finite_temp_toggle,
]
phase_diagram_output = gr.Plot(label="Phase Diagram")

# Wire inputs, callback and output into a single Gradio interface.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=interface_inputs,
    outputs=phase_diagram_output,
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Start serving the app.
iface.launch()