"""LeMaterial phase-diagram viewer.

Loads the ``compatible_pbe``, ``compatible_pbesol`` and ``compatible_scan``
subsets of the ``LeMaterial/leMat1`` dataset, indexes which elements each
entry contains, and serves a Gradio app that builds a pymatgen phase diagram
for a user-supplied chemical system.
"""

import os

import gradio as gr
import numpy as np
import pandas as pd
import periodictable
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

HF_TOKEN = os.environ.get("HF_TOKEN")

subsets = [
    "compatible_pbe",
    "compatible_pbesol",
    "compatible_scan",
]

# Load only the train split of each subset, restricted to the columns used below.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        "LeMaterial/leMat1",
        subset,
        split="train",
        token=HF_TOKEN,
        columns=[
            "lattice_vectors",
            "species_at_sites",
            "cartesian_site_positions",
            "energy",
            "energy_corrected",
            "immutable_id",
            "elements",
            "functional",
        ],
    )
    subsets_ds[subset] = dataset.to_pandas()

elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}

# Element symbol -> column index in the presence matrices built below.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}

# elements_indices[subset][row, col] is True when row `row` of that subset
# lists element `col`. Stored as bool rather than float64 to keep the
# (n_rows x n_elements) matrices 8x smaller.
elements_indices = {}
# Number of distinct elements per row, computed once so each query only has
# to sum over the few queried columns.
element_counts = {}
for subset_key, elements_col in elements_df.items():
    print("Processing subset: ", subset_key)
    lengths = elements_col.map(len).to_numpy(dtype=np.intp)
    # Scatter every (row, element) pair into the matrix in one vectorised step.
    rows = np.repeat(np.arange(len(elements_col)), lengths)
    cols = np.fromiter(
        (all_elements[symbol] for symbols in elements_col for symbol in symbols),
        dtype=np.intp,
        count=int(lengths.sum()),
    )
    presence = np.zeros((len(elements_col), len(all_elements)), dtype=bool)
    presence[rows, cols] = True
    elements_indices[subset_key] = presence
    element_counts[subset_key] = presence.sum(axis=1)

map_functional = {
    "PBE": "compatible_pbe",
    "PBESol": "compatible_pbesol",
    "SCAN": "compatible_scan",
}

# Choices offered by the energy-correction dropdown.
ENERGY_CORRECTIONS = [
    "The 110 PBE Method",
    "Database specific, or MP2020",
]


def _message_figure(text):
    """Return an empty plotly figure that only carries ``text`` as an annotation."""
    return go.Figure().add_annotation(text=text)


def _energy_correction(energy_correction, row):
    """Return the correction (same units as ``row["energy"]``) to add to the raw energy.

    Args:
        energy_correction: one of ``ENERGY_CORRECTIONS``.
        row: a dataset row with ``energy`` and ``energy_corrected`` fields.

    Raises:
        ValueError: for an unrecognised ``energy_correction``.
    """
    if energy_correction == "Database specific, or MP2020":
        corrected = row["energy_corrected"]
        # Rows without a provider-corrected energy get no correction.
        return 0.0 if pd.isna(corrected) else corrected - row["energy"]
    if energy_correction == "The 110 PBE Method":
        # The entry's total energy is energy + correction, so adding 10% of the
        # raw energy scales the total to 110% (1.1 * energy here would give 210%).
        return 0.1 * row["energy"]
    raise ValueError(f"Unknown energy correction: {energy_correction!r}")


def _row_to_entry(row, energy_correction):
    """Build a ComputedStructureEntry from one dataset row."""
    structure = Structure(
        [x.tolist() for x in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    return ComputedStructureEntry(
        structure,
        energy=row["energy"],
        correction=_energy_correction(energy_correction, row),
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system.

    Args:
        elements: element symbols separated by '-', e.g. "Li-Fe-O"
            (symbols are whitespace-stripped and capitalised).
        energy_correction: one of ``ENERGY_CORRECTIONS``.
        plot_style: "2D" or "3D" (3D only for ternary systems).
        functional: a key of ``map_functional`` selecting the dataset subset.
        finite_temp: if truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries``.
        **kwargs: accepted and ignored.

    Returns:
        A plotly figure. Invalid input, an empty selection, or a ValueError
        from ``PhaseDiagram`` yields an empty figure with an explanatory
        annotation instead.
    """
    element_list = [el.strip().capitalize() for el in (elements or "").split("-")]

    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return _message_figure(
            "Unknown element symbol(s): " + ", ".join(repr(el) for el in unknown)
        )
    subset_name = map_functional.get(functional)
    if subset_name is None:
        return _message_figure("Please select a functional.")
    if energy_correction not in ENERGY_CORRECTIONS:
        return _message_figure("Please select an energy correction.")
    if plot_style not in ("2D", "3D"):
        return _message_figure("Please select a plot style.")

    # Deduplicate so inputs like "Li-Li-O" do not double-count a column.
    query_columns = sorted({all_elements[el] for el in element_list})

    # A row belongs to the chemical system iff every element it contains is
    # among the queried ones, i.e. its hits in the queried columns equal its
    # total element count.
    hits = elements_indices[subset_name][:, query_columns].sum(axis=1)
    positions = np.flatnonzero(hits == element_counts[subset_name])

    entries_df = subsets_ds[subset_name].iloc[positions]
    entries_df = entries_df[entries_df["immutable_id"].notna()]
    if entries_df.empty:
        return _message_figure(
            "No entries found for the {} system.".format("-".join(element_list))
        )

    entries = [
        _row_to_entry(row, energy_correction) for _, row in entries_df.iterrows()
    ]

    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        print(e)
        return _message_figure(str(e))

    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
        return plotter.get_plot()

    if len(query_columns) != 3:
        return _message_figure("3D plots are only available for ternary systems.")
    plotter = PDPlotter(
        phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
    )
    # TODO: filter by maximum energy above hull (PDPlotter has no direct option).
    return plotter.get_plot()


# Define Gradio interface components
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# TODO: expose once energy-above-hull filtering is implemented.
# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )
# Explicit defaults so the callback never receives None from an untouched dropdown.
energy_correction_dropdown = gr.Dropdown(
    choices=ENERGY_CORRECTIONS,
    value=ENERGY_CORRECTIONS[0],
    label="Energy correction",
)
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"], value="2D", label="Plot Style"
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br><br>Additionally, we have provided the 110 PBE correction method"
    " from Rohr et al (2024).<br>"
)

# NOTE(review): the HTML here is minimal (line breaks and a plain wrapper);
# the '×' was presumably a dismiss control — restore richer markup/links if needed.
message = (
    "<div><span>×</span>{}</div>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
).format(warning_message)
message += "<br>Built with Pymatgen and Crystal Toolkit.<br>"

# Create Gradio interface
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        # max_e_above_hull_slider,
        energy_correction_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Launch the app
iface.launch()