Spaces:
Running
Running
add equiformer, escn; add leaderboard
Browse files- .gitignore +1 -0
- mlip_arena/models/registry.yaml +8 -8
- mlip_arena/tasks/diatomics/chgnet/homonuclear-diatomics.json +0 -0
- mlip_arena/tasks/diatomics/mace-mp/homonuclear-diatomics.json +0 -0
- mlip_arena/tasks/diatomics/mace-off/homonuclear-diatomics.json +0 -0
- pyproject.toml +56 -4
- serve/app.py +11 -7
- serve/reports/alerts.py +0 -4
- serve/reports/bugs.py +0 -4
- serve/reports/dashboard.py +0 -23
- serve/tasks/homonuclear-diatomics.py +50 -34
- serve/tools/ptable.py +2 -1
.gitignore
CHANGED
@@ -4,6 +4,7 @@ __pycache__/
|
|
4 |
*$py.class
|
5 |
tests/
|
6 |
*.out
|
|
|
7 |
|
8 |
# C extensions
|
9 |
*.so
|
|
|
4 |
*$py.class
|
5 |
tests/
|
6 |
*.out
|
7 |
+
mlip_arena/tasks/*/*/
|
8 |
|
9 |
# C extensions
|
10 |
*.so
|
mlip_arena/models/registry.yaml
CHANGED
@@ -12,11 +12,11 @@ MACE_MP_Medium:
|
|
12 |
gpu-tasks:
|
13 |
- diatomics
|
14 |
|
15 |
-
CHGNet:
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
12 |
gpu-tasks:
|
13 |
- diatomics
|
14 |
|
15 |
+
# CHGNet:
|
16 |
+
# module: chgnet
|
17 |
+
# username: cyrusyc
|
18 |
+
# datetime: 2024-03-25T14:30:00
|
19 |
+
# datasets:
|
20 |
+
# - atomind/mptrj
|
21 |
+
# cpu-tasks:
|
22 |
+
# - diatomics
|
mlip_arena/tasks/diatomics/chgnet/homonuclear-diatomics.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
mlip_arena/tasks/diatomics/mace-mp/homonuclear-diatomics.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
mlip_arena/tasks/diatomics/mace-off/homonuclear-diatomics.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
CHANGED
@@ -26,14 +26,20 @@ classifiers=[
|
|
26 |
"Programming Language :: Python :: 3 :: Only",
|
27 |
]
|
28 |
dependencies=[
|
29 |
-
"torch
|
30 |
"ase",
|
31 |
"torch_dftd>=0.4.0",
|
32 |
"huggingface_hub",
|
33 |
-
"torch-geometric
|
34 |
"safetensors"
|
35 |
]
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
[project.urls]
|
38 |
Homepage = "https://github.com/atomind-ai/mlip-arena"
|
39 |
Issues = "https://github.com/atomind-ai/mlip-arena/issues"
|
@@ -74,7 +80,53 @@ line-length = 88
|
|
74 |
indent-width = 4
|
75 |
|
76 |
[tool.ruff.lint]
|
77 |
-
select = [
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
fixable = ["ALL"]
|
80 |
pydocstyle.convention = "google"
|
|
|
26 |
"Programming Language :: Python :: 3 :: Only",
|
27 |
]
|
28 |
dependencies=[
|
29 |
+
"torch",
|
30 |
"ase",
|
31 |
"torch_dftd>=0.4.0",
|
32 |
"huggingface_hub",
|
33 |
+
"torch-geometric",
|
34 |
"safetensors"
|
35 |
]
|
36 |
|
37 |
+
[project.optional-dependencies]
|
38 |
+
m3gnet = ["matgl", "dgl", "torch<=2.2.1"]
|
39 |
+
mace = ["mace-torch"]
|
40 |
+
chgnet = ["chgnet"]
|
41 |
+
fairchem = ["fairchem"]
|
42 |
+
|
43 |
[project.urls]
|
44 |
Homepage = "https://github.com/atomind-ai/mlip-arena"
|
45 |
Issues = "https://github.com/atomind-ai/mlip-arena/issues"
|
|
|
80 |
indent-width = 4
|
81 |
|
82 |
[tool.ruff.lint]
|
83 |
+
select = [
|
84 |
+
"B", # flake8-bugbear
|
85 |
+
"C4", # flake8-comprehensions
|
86 |
+
"E", # pycodestyle error
|
87 |
+
"EXE", # flake8-executable
|
88 |
+
"F", # pyflakes
|
89 |
+
"FA", # flake8-future-annotations
|
90 |
+
"FBT003", # boolean-positional-value-in-call
|
91 |
+
"FLY", # flynt
|
92 |
+
"I", # isort
|
93 |
+
"ICN", # flake8-import-conventions
|
94 |
+
"PD", # pandas-vet
|
95 |
+
"PERF", # perflint
|
96 |
+
"PIE", # flake8-pie
|
97 |
+
"PL", # pylint
|
98 |
+
"PT", # flake8-pytest-style
|
99 |
+
"PYI", # flakes8-pyi
|
100 |
+
"Q", # flake8-quotes
|
101 |
+
"RET", # flake8-return
|
102 |
+
"RSE", # flake8-raise
|
103 |
+
"RUF", # Ruff-specific rules
|
104 |
+
"SIM", # flake8-simplify
|
105 |
+
"SLOT", # flake8-slots
|
106 |
+
"TCH", # flake8-type-checking
|
107 |
+
"TID", # tidy imports
|
108 |
+
"TID", # flake8-tidy-imports
|
109 |
+
"UP", # pyupgrade
|
110 |
+
"W", # pycodestyle warning
|
111 |
+
"YTT", # flake8-2020
|
112 |
+
]
|
113 |
+
ignore = [
|
114 |
+
"C408", # Unnecessary dict call
|
115 |
+
"PLR", # Design related pylint codes
|
116 |
+
"E501", # Line too long
|
117 |
+
"B028", # No explicit stacklevel
|
118 |
+
"EM101", # Exception must not use a string literal
|
119 |
+
"EM102", # Exception must not use an f-string literal
|
120 |
+
"G004", # f-string in Logging statement
|
121 |
+
"RUF015", # Prefer next(iter())
|
122 |
+
"RET505", # Unnecessary `elif` after `return`
|
123 |
+
"PT004", # Fixture does not return anthing
|
124 |
+
"B017", # pytest.raises
|
125 |
+
"PT011", # pytest.raises
|
126 |
+
"PT012", # pytest.raises"
|
127 |
+
"E741", # ambigous variable naming, i.e. one letter
|
128 |
+
"FBT003", # boolean positional variable in function call
|
129 |
+
"PERF203", # `try`-`except` within a loop incurs performance overhead (no overhead in Py 3.11+)
|
130 |
+
]
|
131 |
fixable = ["ALL"]
|
132 |
pydocstyle.convention = "google"
|
serve/app.py
CHANGED
@@ -4,8 +4,11 @@ st.set_page_config(
|
|
4 |
layout="wide",
|
5 |
page_title="MLIP Arena",
|
6 |
page_icon=":shark:",
|
7 |
-
|
8 |
-
menu_items=
|
|
|
|
|
|
|
9 |
)
|
10 |
|
11 |
# if "logged_in" not in st.session_state:
|
@@ -24,19 +27,19 @@ st.set_page_config(
|
|
24 |
# login_page = st.Page(login, title="Log in", icon=":material/login:")
|
25 |
# logout_page = st.Page(logout, title="Log out", icon=":material/logout:")
|
26 |
|
27 |
-
|
28 |
-
"
|
29 |
)
|
30 |
-
bugs = st.Page("
|
31 |
alerts = st.Page(
|
32 |
-
"
|
33 |
)
|
34 |
|
35 |
search = st.Page("tools/search.py", title="Search", icon=":material/search:")
|
36 |
history = st.Page("tools/history.py", title="History", icon=":material/history:")
|
37 |
ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
|
38 |
|
39 |
-
diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon="", default=True)
|
40 |
|
41 |
# if st.session_state.logged_in:
|
42 |
pg = st.navigation(
|
@@ -44,6 +47,7 @@ pg = st.navigation(
|
|
44 |
# "Account": [logout_page],
|
45 |
# "Reports": [dashboard, bugs, alerts],
|
46 |
# "Tools": [search, history, ptable],
|
|
|
47 |
"Tasks": [diatomics],
|
48 |
"Tools": [ptable],
|
49 |
}
|
|
|
4 |
layout="wide",
|
5 |
page_title="MLIP Arena",
|
6 |
page_icon=":shark:",
|
7 |
+
initial_sidebar_state="expanded",
|
8 |
+
menu_items={
|
9 |
+
"About": 'https://github.com/atomind-ai/mlip-arena',
|
10 |
+
"Report a bug": "https://github.com/atomind-ai/mlip-arena/issues/new",
|
11 |
+
}
|
12 |
)
|
13 |
|
14 |
# if "logged_in" not in st.session_state:
|
|
|
27 |
# login_page = st.Page(login, title="Log in", icon=":material/login:")
|
28 |
# logout_page = st.Page(logout, title="Log out", icon=":material/logout:")
|
29 |
|
30 |
+
leaderboard = st.Page(
|
31 |
+
"models/leaderboard.py", title="Leaderboard", icon=":material/trophy:"
|
32 |
)
|
33 |
+
bugs = st.Page("models/bugs.py", title="Bug reports", icon=":material/bug_report:")
|
34 |
alerts = st.Page(
|
35 |
+
"models/alerts.py", title="System alerts", icon=":material/notification_important:"
|
36 |
)
|
37 |
|
38 |
search = st.Page("tools/search.py", title="Search", icon=":material/search:")
|
39 |
history = st.Page("tools/history.py", title="History", icon=":material/history:")
|
40 |
ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
|
41 |
|
42 |
+
diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
|
43 |
|
44 |
# if st.session_state.logged_in:
|
45 |
pg = st.navigation(
|
|
|
47 |
# "Account": [logout_page],
|
48 |
# "Reports": [dashboard, bugs, alerts],
|
49 |
# "Tools": [search, history, ptable],
|
50 |
+
"Models": [leaderboard],
|
51 |
"Tasks": [diatomics],
|
52 |
"Tools": [ptable],
|
53 |
}
|
serve/reports/alerts.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
|
3 |
-
|
4 |
-
st.markdown("# Alerts")
|
|
|
|
|
|
|
|
|
|
serve/reports/bugs.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
|
3 |
-
|
4 |
-
st.markdown("# Bugs")
|
|
|
|
|
|
|
|
|
|
serve/reports/dashboard.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import plotly.figure_factory as ff
|
3 |
-
import streamlit as st
|
4 |
-
|
5 |
-
st.markdown("# Dashboard")
|
6 |
-
|
7 |
-
# Add histogram data
|
8 |
-
x1 = np.random.randn(200) - 2
|
9 |
-
x2 = np.random.randn(200)
|
10 |
-
x3 = np.random.randn(200) + 2
|
11 |
-
|
12 |
-
# Group data together
|
13 |
-
hist_data = [x1, x2, x3]
|
14 |
-
|
15 |
-
group_labels = ["Group 1", "Group 2", "Group 3"]
|
16 |
-
|
17 |
-
# Create distplot with custom bin_size
|
18 |
-
fig = ff.create_distplot(
|
19 |
-
hist_data, group_labels, bin_size=[.1, .25, .5]
|
20 |
-
)
|
21 |
-
|
22 |
-
# Plot!
|
23 |
-
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
serve/tasks/homonuclear-diatomics.py
CHANGED
@@ -9,35 +9,46 @@ from ase.data import chemical_symbols
|
|
9 |
from plotly.subplots import make_subplots
|
10 |
from scipy.interpolate import CubicSpline
|
11 |
|
12 |
-
color_sequence = pcolors.qualitative.Plotly
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
st.markdown("# Homonuclear diatomics")
|
17 |
|
18 |
-
|
19 |
container = st.container(border=True)
|
20 |
-
|
21 |
-
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
|
26 |
-
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
|
|
|
|
|
|
|
|
|
31 |
|
|
|
32 |
|
|
|
|
|
|
|
33 |
df.drop_duplicates(inplace=True, subset=["name", "method"])
|
34 |
|
|
|
|
|
35 |
for i, symbol in enumerate(chemical_symbols[1:]):
|
36 |
|
37 |
if i % ncols == 0:
|
38 |
cols = st.columns(ncols)
|
39 |
|
40 |
-
|
41 |
rows = df[df["name"] == symbol + symbol]
|
42 |
|
43 |
if rows.empty:
|
@@ -61,23 +72,34 @@ for i, symbol in enumerate(chemical_symbols[1:]):
|
|
61 |
|
62 |
rs = rs[ind]
|
63 |
es = es[ind]
|
64 |
-
|
65 |
-
|
|
|
|
|
66 |
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
if energy_plot:
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
72 |
|
73 |
-
elo = min(elo, ys.min()*1.2, -1)
|
74 |
|
75 |
fig.add_trace(
|
76 |
go.Scatter(
|
77 |
x=xs, y=ys,
|
78 |
mode="lines",
|
79 |
line=dict(
|
80 |
-
color=
|
81 |
width=2,
|
82 |
),
|
83 |
name=method,
|
@@ -85,33 +107,32 @@ for i, symbol in enumerate(chemical_symbols[1:]):
|
|
85 |
secondary_y=False,
|
86 |
)
|
87 |
|
88 |
-
if force_plot:
|
89 |
-
|
90 |
-
ys = cs(xs)
|
91 |
|
92 |
-
flo = min(flo, ys.min()*1.2)
|
93 |
|
94 |
fig.add_trace(
|
95 |
go.Scatter(
|
96 |
x=xs, y=ys,
|
97 |
mode="lines",
|
98 |
line=dict(
|
99 |
-
color=
|
100 |
width=1,
|
101 |
dash="dot",
|
102 |
),
|
103 |
name=method,
|
104 |
-
showlegend=
|
105 |
),
|
106 |
secondary_y=True,
|
107 |
)
|
108 |
|
|
|
109 |
|
110 |
fig.update_layout(
|
111 |
showlegend=True,
|
112 |
-
title_text=f"{
|
113 |
title_x=0.5,
|
114 |
-
# yaxis_range=[ylo, 2*(abs(ylo))],
|
115 |
)
|
116 |
|
117 |
# Set x-axis title
|
@@ -128,21 +149,16 @@ for i, symbol in enumerate(chemical_symbols[1:]):
|
|
128 |
)
|
129 |
)
|
130 |
|
131 |
-
# fig.update_yaxes(title_text="Energy [eV]", secondary_y=False)
|
132 |
-
|
133 |
if force_plot:
|
134 |
|
135 |
fig.update_layout(
|
136 |
yaxis2=dict(
|
137 |
title=dict(text="Force [eV/Å]"),
|
138 |
side="right",
|
139 |
-
range=[flo,
|
140 |
overlaying="y",
|
141 |
tickmode="sync",
|
142 |
),
|
143 |
)
|
144 |
|
145 |
-
# fig.update_yaxes(title_text="Force [eV/Å]", secondary_y=True)
|
146 |
-
|
147 |
-
# cols[i % ncols].title(f"{row['name']}")
|
148 |
cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)
|
|
|
9 |
from plotly.subplots import make_subplots
|
10 |
from scipy.interpolate import CubicSpline
|
11 |
|
|
|
|
|
|
|
|
|
12 |
st.markdown("# Homonuclear diatomics")
|
13 |
|
14 |
+
st.markdown("### Methods")
|
15 |
container = st.container(border=True)
|
16 |
+
methods = container.multiselect("MLIPs", ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF", "eSCN"], ["MACE-MP", "Equiformer", "CHGNet", "eSCN"])
|
17 |
+
methods += container.multiselect("DFT Methods", ["GPAW"], [])
|
18 |
|
19 |
+
st.markdown("### Settings")
|
20 |
+
vis = st.container(border=True)
|
21 |
+
energy_plot = vis.checkbox("Show energy curves", value=True)
|
22 |
+
force_plot = vis.checkbox("Show force curves", value=True)
|
23 |
+
ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=3)
|
24 |
|
25 |
+
# Get all attributes from pcolors.qualitative
|
26 |
+
all_attributes = dir(pcolors.qualitative)
|
27 |
+
color_palettes = {attr: getattr(pcolors.qualitative, attr) for attr in all_attributes if isinstance(getattr(pcolors.qualitative, attr), list)}
|
28 |
+
color_palettes.pop("__all__", None)
|
29 |
|
30 |
+
palette_names = list(color_palettes.keys())
|
31 |
+
palette_colors = list(color_palettes.values())
|
32 |
|
33 |
+
palette_name = vis.selectbox(
|
34 |
+
"Color sequence",
|
35 |
+
options=palette_names, index=22
|
36 |
+
)
|
37 |
|
38 |
+
color_sequence = color_palettes[palette_name] # type: ignore
|
39 |
|
40 |
+
DATA_DIR = Path("mlip_arena/tasks/diatomics")
|
41 |
+
dfs = [pd.read_json(DATA_DIR / method.lower() / "homonuclear-diatomics.json") for method in methods]
|
42 |
+
df = pd.concat(dfs, ignore_index=True)
|
43 |
df.drop_duplicates(inplace=True, subset=["name", "method"])
|
44 |
|
45 |
+
method_color_mapping = {method: color_sequence[i % len(color_sequence)] for i, method in enumerate(df["method"].unique())}
|
46 |
+
|
47 |
for i, symbol in enumerate(chemical_symbols[1:]):
|
48 |
|
49 |
if i % ncols == 0:
|
50 |
cols = st.columns(ncols)
|
51 |
|
|
|
52 |
rows = df[df["name"] == symbol + symbol]
|
53 |
|
54 |
if rows.empty:
|
|
|
72 |
|
73 |
rs = rs[ind]
|
74 |
es = es[ind]
|
75 |
+
if "GPAW" not in method:
|
76 |
+
es = es - es[-1]
|
77 |
+
else:
|
78 |
+
pass
|
79 |
|
80 |
+
if "GPAW" not in method:
|
81 |
+
fs = fs[ind]
|
82 |
+
|
83 |
+
if "GPAW" in method:
|
84 |
+
xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
|
85 |
+
else:
|
86 |
+
xs = rs
|
87 |
|
88 |
if energy_plot:
|
89 |
+
if "GPAW" in method:
|
90 |
+
cs = CubicSpline(rs, es)
|
91 |
+
ys = cs(xs)
|
92 |
+
else:
|
93 |
+
ys = es
|
94 |
|
95 |
+
elo = min(elo, max(ys.min()*1.2, -15), -1)
|
96 |
|
97 |
fig.add_trace(
|
98 |
go.Scatter(
|
99 |
x=xs, y=ys,
|
100 |
mode="lines",
|
101 |
line=dict(
|
102 |
+
color=method_color_mapping[method],
|
103 |
width=2,
|
104 |
),
|
105 |
name=method,
|
|
|
107 |
secondary_y=False,
|
108 |
)
|
109 |
|
110 |
+
if force_plot and "GPAW" not in method:
|
111 |
+
ys = fs
|
|
|
112 |
|
113 |
+
flo = min(flo, max(ys.min()*1.2, -50))
|
114 |
|
115 |
fig.add_trace(
|
116 |
go.Scatter(
|
117 |
x=xs, y=ys,
|
118 |
mode="lines",
|
119 |
line=dict(
|
120 |
+
color=method_color_mapping[method],
|
121 |
width=1,
|
122 |
dash="dot",
|
123 |
),
|
124 |
name=method,
|
125 |
+
showlegend=not energy_plot,
|
126 |
),
|
127 |
secondary_y=True,
|
128 |
)
|
129 |
|
130 |
+
name = f"{symbol}-{symbol}"
|
131 |
|
132 |
fig.update_layout(
|
133 |
showlegend=True,
|
134 |
+
title_text=f"{name}",
|
135 |
title_x=0.5,
|
|
|
136 |
)
|
137 |
|
138 |
# Set x-axis title
|
|
|
149 |
)
|
150 |
)
|
151 |
|
|
|
|
|
152 |
if force_plot:
|
153 |
|
154 |
fig.update_layout(
|
155 |
yaxis2=dict(
|
156 |
title=dict(text="Force [eV/Å]"),
|
157 |
side="right",
|
158 |
+
range=[flo, 1.5*abs(flo)],
|
159 |
overlaying="y",
|
160 |
tickmode="sync",
|
161 |
),
|
162 |
)
|
163 |
|
|
|
|
|
|
|
164 |
cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)
|
serve/tools/ptable.py
CHANGED
@@ -76,7 +76,8 @@ def display_element_details():
|
|
76 |
if st.button("Back to Periodic Table"):
|
77 |
st.session_state.selected_element = None
|
78 |
st.session_state.selected_name = None
|
79 |
-
st.
|
|
|
80 |
|
81 |
|
82 |
st.title("Periodic Table")
|
|
|
76 |
if st.button("Back to Periodic Table"):
|
77 |
st.session_state.selected_element = None
|
78 |
st.session_state.selected_name = None
|
79 |
+
st.rerun()
|
80 |
+
# st.experimental_rerun()
|
81 |
|
82 |
|
83 |
st.title("Periodic Table")
|