Spaces:
Running
Running
update ptable, add statbility, update pyproject.toml
Browse files- mlip_arena/models/registry.yaml +4 -0
- mlip_arena/tasks/stability/run.py +1 -1
- pyproject.toml +8 -2
- serve/app.py +3 -1
- serve/tasks/homonuclear-diatomics.py +2 -0
- serve/tasks/stability.py +28 -0
- serve/tools/ptable.py +77 -92
mlip_arena/models/registry.yaml
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
MACE-MP(M):
|
4 |
module: externals
|
5 |
class: MACE_MP_Medium
|
|
|
6 |
username: cyrusyc # HF username
|
7 |
last-update: 2024-03-25T14:30:00
|
8 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
@@ -20,6 +21,7 @@ MACE-MP(M):
|
|
20 |
CHGNet:
|
21 |
module: externals
|
22 |
class: CHGNet
|
|
|
23 |
username: cyrusyc
|
24 |
last-update: 2024-07-08T00:00:00
|
25 |
datetime: 2024-07-08T00:00:00
|
@@ -31,6 +33,7 @@ CHGNet:
|
|
31 |
EquiformerV2(OC22):
|
32 |
module: externals
|
33 |
class: EquiformerV2
|
|
|
34 |
username: cyrusyc
|
35 |
last-update: 2024-07-08T00:00:00
|
36 |
datetime: 2024-07-08T00:00:00
|
@@ -42,6 +45,7 @@ EquiformerV2(OC22):
|
|
42 |
eSCN(OC20):
|
43 |
module: externals
|
44 |
class: eSCN
|
|
|
45 |
username: cyrusyc
|
46 |
last-update: 2024-07-08T00:00:00
|
47 |
datetime: 2024-07-08T00:00:00
|
|
|
3 |
MACE-MP(M):
|
4 |
module: externals
|
5 |
class: MACE_MP_Medium
|
6 |
+
family: mace
|
7 |
username: cyrusyc # HF username
|
8 |
last-update: 2024-03-25T14:30:00
|
9 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
|
|
21 |
CHGNet:
|
22 |
module: externals
|
23 |
class: CHGNet
|
24 |
+
family: chgnet
|
25 |
username: cyrusyc
|
26 |
last-update: 2024-07-08T00:00:00
|
27 |
datetime: 2024-07-08T00:00:00
|
|
|
33 |
EquiformerV2(OC22):
|
34 |
module: externals
|
35 |
class: EquiformerV2
|
36 |
+
family: equiformer
|
37 |
username: cyrusyc
|
38 |
last-update: 2024-07-08T00:00:00
|
39 |
datetime: 2024-07-08T00:00:00
|
|
|
45 |
eSCN(OC20):
|
46 |
module: externals
|
47 |
class: eSCN
|
48 |
+
family: escn
|
49 |
username: cyrusyc
|
50 |
last-update: 2024-07-08T00:00:00
|
51 |
datetime: 2024-07-08T00:00:00
|
mlip_arena/tasks/stability/run.py
CHANGED
@@ -119,7 +119,7 @@ def _get_ensemble_defaults(
|
|
119 |
ase_md_kwargs.pop("externalstress", None)
|
120 |
elif ensemble == "npt":
|
121 |
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
122 |
-
ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
|
123 |
|
124 |
if isinstance(dynamics, str) and dynamics.lower() == "langevin":
|
125 |
ase_md_kwargs["friction"] = ase_md_kwargs.get(
|
|
|
119 |
ase_md_kwargs.pop("externalstress", None)
|
120 |
elif ensemble == "npt":
|
121 |
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
122 |
+
ase_md_kwargs["externalstress"] = p_schedule[0] # * 1e3 * units.bar
|
123 |
|
124 |
if isinstance(dynamics, str) and dynamics.lower() == "langevin":
|
125 |
ase_md_kwargs["friction"] = ase_md_kwargs.get(
|
pyproject.toml
CHANGED
@@ -26,12 +26,13 @@ classifiers=[
|
|
26 |
"Programming Language :: Python :: 3 :: Only",
|
27 |
]
|
28 |
dependencies=[
|
29 |
-
"torch",
|
30 |
"ase",
|
|
|
31 |
"torch_dftd>=0.4.0",
|
32 |
"huggingface_hub",
|
33 |
"torch-geometric",
|
34 |
-
"safetensors"
|
|
|
35 |
]
|
36 |
|
37 |
[project.optional-dependencies]
|
@@ -39,6 +40,11 @@ m3gnet = ["matgl", "dgl", "torch<=2.2.1"]
|
|
39 |
mace = ["mace-torch"]
|
40 |
chgnet = ["chgnet"]
|
41 |
fairchem = ["fairchem"]
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
[project.urls]
|
44 |
Homepage = "https://github.com/atomind-ai/mlip-arena"
|
|
|
26 |
"Programming Language :: Python :: 3 :: Only",
|
27 |
]
|
28 |
dependencies=[
|
|
|
29 |
"ase",
|
30 |
+
"torch",
|
31 |
"torch_dftd>=0.4.0",
|
32 |
"huggingface_hub",
|
33 |
"torch-geometric",
|
34 |
+
"safetensors",
|
35 |
+
"pymatgen"
|
36 |
]
|
37 |
|
38 |
[project.optional-dependencies]
|
|
|
40 |
mace = ["mace-torch"]
|
41 |
chgnet = ["chgnet"]
|
42 |
fairchem = ["fairchem"]
|
43 |
+
app = [
|
44 |
+
"streamlit",
|
45 |
+
"plotly",
|
46 |
+
"bokeh==2.4.3",
|
47 |
+
]
|
48 |
|
49 |
[project.urls]
|
50 |
Homepage = "https://github.com/atomind-ai/mlip-arena"
|
serve/app.py
CHANGED
@@ -40,6 +40,8 @@ history = st.Page("tools/history.py", title="History", icon=":material/history:"
|
|
40 |
ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
|
41 |
|
42 |
diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
|
|
|
|
|
43 |
|
44 |
# if st.session_state.logged_in:
|
45 |
pg = st.navigation(
|
@@ -48,7 +50,7 @@ pg = st.navigation(
|
|
48 |
# "Reports": [dashboard, bugs, alerts],
|
49 |
# "Tools": [search, history, ptable],
|
50 |
"": [leaderboard],
|
51 |
-
"Tasks": [diatomics],
|
52 |
"Tools": [ptable],
|
53 |
}
|
54 |
)
|
|
|
40 |
ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
|
41 |
|
42 |
diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
|
43 |
+
stability = st.Page("tasks/stability.py", title="Stability", icon=":material/target:")
|
44 |
+
|
45 |
|
46 |
# if st.session_state.logged_in:
|
47 |
pg = st.navigation(
|
|
|
50 |
# "Reports": [dashboard, bugs, alerts],
|
51 |
# "Tools": [search, history, ptable],
|
52 |
"": [leaderboard],
|
53 |
+
"Tasks": [diatomics, stability],
|
54 |
"Tools": [ptable],
|
55 |
}
|
56 |
)
|
serve/tasks/homonuclear-diatomics.py
CHANGED
@@ -9,6 +9,8 @@ from ase.data import chemical_symbols
|
|
9 |
from plotly.subplots import make_subplots
|
10 |
from scipy.interpolate import CubicSpline
|
11 |
|
|
|
|
|
12 |
st.markdown("# Homonuclear diatomics")
|
13 |
|
14 |
st.markdown("### Methods")
|
|
|
9 |
from plotly.subplots import make_subplots
|
10 |
from scipy.interpolate import CubicSpline
|
11 |
|
12 |
+
from mlip_arena.models.utils import MLIPMap
|
13 |
+
|
14 |
st.markdown("# Homonuclear diatomics")
|
15 |
|
16 |
st.markdown("### Methods")
|
serve/tasks/stability.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import plotly.colors as pcolors
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
import streamlit as st
|
8 |
+
from ase.data import chemical_symbols
|
9 |
+
from ase.io import read, write
|
10 |
+
from plotly.subplots import make_subplots
|
11 |
+
from scipy.interpolate import CubicSpline
|
12 |
+
|
13 |
+
from mlip_arena.models.utils import MLIPMap
|
14 |
+
|
15 |
+
st.markdown("# Stability")
|
16 |
+
|
17 |
+
st.markdown("### Methods")
|
18 |
+
container = st.container(border=True)
|
19 |
+
methods = container.multiselect("MLIPs", ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF", "eSCN", "ALIGNN"], ["MACE-MP", "Equiformer", "CHGNet", "eSCN", "ALIGNN"])
|
20 |
+
|
21 |
+
|
22 |
+
DATA_DIR = Path("mlip_arena/tasks/stability")
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
serve/tools/ptable.py
CHANGED
@@ -1,93 +1,78 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from ase.data import chemical_symbols
|
3 |
-
from pymatgen.core import Element
|
4 |
-
|
5 |
-
elements = [Element.from_Z(z) for z in range(1, 119)]
|
6 |
-
|
7 |
-
# Define the number of rows and columns in the periodic table
|
8 |
-
rows = 9 # There are 7 rows in the conventional periodic table
|
9 |
-
columns = 18
|
10 |
-
|
11 |
-
# Define a function to display the periodic table
|
12 |
-
def display_periodic_table():
|
13 |
-
# elements = [
|
14 |
-
# (element, element) for element in chemical_symbols[1:]
|
15 |
-
# ]
|
16 |
-
|
17 |
-
# cols = st.columns(18, gap='small', vertical_alignment='bottom') # Create 18 columns for the periodic table layout
|
18 |
-
|
19 |
-
row = 0
|
20 |
-
for element in elements:
|
21 |
-
symbol = element.symbol
|
22 |
-
atomic_number = element.Z
|
23 |
-
group = element.group
|
24 |
-
|
25 |
-
if element.row > row:
|
26 |
-
cols = st.columns(columns, gap='small', vertical_alignment='bottom')
|
27 |
-
row = element.row
|
28 |
-
|
29 |
-
if element.block == 'f':
|
30 |
-
continue
|
31 |
-
|
32 |
-
with cols[group - 1]:
|
33 |
-
if st.button(symbol, use_container_width=True):
|
34 |
-
st.session_state.selected_element = symbol
|
35 |
-
st.session_state.selected_name = symbol
|
36 |
-
st.rerun()
|
37 |
-
# st.experimental_rerun()
|
38 |
-
|
39 |
-
for element in elements:
|
40 |
-
symbol = element.symbol
|
41 |
-
atomic_number = element.Z
|
42 |
-
group = element.group
|
43 |
-
|
44 |
-
if element.row > row:
|
45 |
-
cols = st.columns(columns, gap='small', vertical_alignment='bottom')
|
46 |
-
row = element.row
|
47 |
-
|
48 |
-
if element.block == 'f':
|
49 |
-
noble = Element.from_row_and_group(row-1, 18)
|
50 |
-
row += 2
|
51 |
-
group += atomic_number - noble.Z - 2
|
52 |
-
else:
|
53 |
-
continue
|
54 |
-
|
55 |
-
with cols[group - 1]:
|
56 |
-
if st.button(symbol, use_container_width=True):
|
57 |
-
st.session_state.selected_element = symbol
|
58 |
-
st.session_state.selected_name = symbol
|
59 |
-
st.rerun()
|
60 |
-
# st.experimental_rerun()
|
61 |
-
|
62 |
-
|
63 |
-
# for idx, (symbol, name) in enumerate(elements):
|
64 |
-
# with cols[idx % 18]: # Place each element in the correct column
|
65 |
-
# if st.button(symbol, use_container_width=True):
|
66 |
-
# st.session_state.selected_element = symbol
|
67 |
-
# st.session_state.selected_name = name
|
68 |
-
# st.experimental_rerun()
|
69 |
-
|
70 |
-
# Define a function to display the details of an element
|
71 |
-
def display_element_details():
|
72 |
-
symbol = st.session_state.selected_element
|
73 |
-
name = st.session_state.selected_name
|
74 |
-
st.write(f"### {name} ({symbol})")
|
75 |
-
st.write(f"Details about {name} ({symbol}) will be displayed here.")
|
76 |
-
if st.button("Back to Periodic Table"):
|
77 |
-
st.session_state.selected_element = None
|
78 |
-
st.session_state.selected_name = None
|
79 |
-
st.rerun()
|
80 |
-
# st.experimental_rerun()
|
81 |
-
|
82 |
-
|
83 |
-
st.title("Periodic Table")
|
84 |
-
|
85 |
-
# st.balloons()
|
86 |
-
if 'selected_element' not in st.session_state:
|
87 |
-
st.session_state.selected_element = None
|
88 |
-
|
89 |
-
if st.session_state.selected_element:
|
90 |
-
display_element_details()
|
91 |
-
else:
|
92 |
-
display_periodic_table()
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
+
|
3 |
+
# NOTE: https://stackoverflow.com/questions/77062368/streamlit-bokeh-event-callback-to-get-clicked-values
|
4 |
+
# Taptool: https://docs.bokeh.org/en/2.4.2/docs/reference/models/tools.html#taptool
|
5 |
+
|
6 |
+
import streamlit as st
|
7 |
+
from bokeh.plotting import figure
|
8 |
+
from bokeh.plotting import figure, show
|
9 |
+
from bokeh.sampledata.periodic_table import elements
|
10 |
+
from bokeh.transform import dodge, factor_cmap
|
11 |
+
|
12 |
+
periods = ["I", "II", "III", "IV", "V", "VI", "VII"]
|
13 |
+
groups = [str(x) for x in range(1, 19)]
|
14 |
+
|
15 |
+
df = elements.copy()
|
16 |
+
df["atomic mass"] = df["atomic mass"].astype(str)
|
17 |
+
df["group"] = df["group"].astype(str)
|
18 |
+
df["period"] = [periods[x-1] for x in df.period]
|
19 |
+
df = df[df.group != "-"]
|
20 |
+
df = df[df.symbol != "Lr"]
|
21 |
+
df = df[df.symbol != "Lu"]
|
22 |
+
|
23 |
+
cmap = {
|
24 |
+
"alkali metal" : "#a6cee3",
|
25 |
+
"alkaline earth metal" : "#1f78b4",
|
26 |
+
"metal" : "#d93b43",
|
27 |
+
"halogen" : "#999d9a",
|
28 |
+
"metalloid" : "#e08d49",
|
29 |
+
"noble gas" : "#eaeaea",
|
30 |
+
"nonmetal" : "#f1d4Af",
|
31 |
+
"transition metal" : "#599d7A",
|
32 |
+
}
|
33 |
+
|
34 |
+
TOOLTIPS = [
|
35 |
+
("Name", "@name"),
|
36 |
+
("Atomic number", "@{atomic number}"),
|
37 |
+
("Atomic mass", "@{atomic mass}"),
|
38 |
+
("Type", "@metal"),
|
39 |
+
("CPK color", "$color[hex, swatch]:CPK"),
|
40 |
+
("Electronic configuration", "@{electronic configuration}"),
|
41 |
+
]
|
42 |
+
|
43 |
+
p = figure(title="Periodic Table (omitting LA and AC Series)", width=1000, height=450,
|
44 |
+
x_range=groups, y_range=list(reversed(periods)),
|
45 |
+
tools="hover", toolbar_location=None, tooltips=TOOLTIPS)
|
46 |
+
|
47 |
+
r = p.rect("group", "period", 0.95, 0.95, source=df, fill_alpha=0.6, legend_field="metal",
|
48 |
+
color=factor_cmap('metal', palette=list(cmap.values()), factors=list(cmap.keys())))
|
49 |
+
|
50 |
+
text_props = dict(source=df, text_align="left", text_baseline="middle")
|
51 |
+
|
52 |
+
x = dodge("group", -0.4, range=p.x_range)
|
53 |
+
|
54 |
+
p.text(x=x, y="period", text="symbol", text_font_style="bold", **text_props)
|
55 |
+
|
56 |
+
p.text(x=x, y=dodge("period", 0.3, range=p.y_range), text="atomic number",
|
57 |
+
text_font_size="11px", **text_props)
|
58 |
+
|
59 |
+
p.text(x=x, y=dodge("period", -0.35, range=p.y_range), text="name",
|
60 |
+
text_font_size="7px", **text_props)
|
61 |
+
|
62 |
+
p.text(x=x, y=dodge("period", -0.2, range=p.y_range), text="atomic mass",
|
63 |
+
text_font_size="7px", **text_props)
|
64 |
+
|
65 |
+
p.text(x=["3", "3"], y=["VI", "VII"], text=["LA", "AC"], text_align="center", text_baseline="middle")
|
66 |
+
|
67 |
+
p.outline_line_color = None
|
68 |
+
p.grid.grid_line_color = None
|
69 |
+
p.axis.axis_line_color = None
|
70 |
+
p.axis.major_tick_line_color = None
|
71 |
+
p.axis.major_label_standoff = 0
|
72 |
+
p.legend.orientation = "horizontal"
|
73 |
+
p.legend.location ="top_center"
|
74 |
+
p.hover.renderers = [r] # only hover element boxes
|
75 |
+
|
76 |
+
st.bokeh_chart(p, use_container_width=True)
|
77 |
+
|
78 |
+
# show(p)
|