cyrusyc commited on
Commit
221dfe3
1 Parent(s): 89bc52a

add equiformer, escn; add leaderboard

Browse files
.gitignore CHANGED
@@ -4,6 +4,7 @@ __pycache__/
4
  *$py.class
5
  tests/
6
  *.out
 
7
 
8
  # C extensions
9
  *.so
 
4
  *$py.class
5
  tests/
6
  *.out
7
+ mlip_arena/tasks/*/*/
8
 
9
  # C extensions
10
  *.so
mlip_arena/models/registry.yaml CHANGED
@@ -12,11 +12,11 @@ MACE_MP_Medium:
12
  gpu-tasks:
13
  - diatomics
14
 
15
- CHGNet:
16
- module: chgnet
17
- username: cyrusyc
18
- datetime: 2024-03-25T14:30:00
19
- datasets:
20
- - atomind/mptrj
21
- cpu-tasks:
22
- - diatomics
 
12
  gpu-tasks:
13
  - diatomics
14
 
15
+ # CHGNet:
16
+ # module: chgnet
17
+ # username: cyrusyc
18
+ # datetime: 2024-03-25T14:30:00
19
+ # datasets:
20
+ # - atomind/mptrj
21
+ # cpu-tasks:
22
+ # - diatomics
mlip_arena/tasks/diatomics/chgnet/homonuclear-diatomics.json CHANGED
The diff for this file is too large to render. See raw diff
 
mlip_arena/tasks/diatomics/mace-mp/homonuclear-diatomics.json CHANGED
The diff for this file is too large to render. See raw diff
 
mlip_arena/tasks/diatomics/mace-off/homonuclear-diatomics.json ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -26,14 +26,20 @@ classifiers=[
26
  "Programming Language :: Python :: 3 :: Only",
27
  ]
28
  dependencies=[
29
- "torch>=2.0.0",
30
  "ase",
31
  "torch_dftd>=0.4.0",
32
  "huggingface_hub",
33
- "torch-geometric>=2.5.2",
34
  "safetensors"
35
  ]
36
 
 
 
 
 
 
 
37
  [project.urls]
38
  Homepage = "https://github.com/atomind-ai/mlip-arena"
39
  Issues = "https://github.com/atomind-ai/mlip-arena/issues"
@@ -74,7 +80,53 @@ line-length = 88
74
  indent-width = 4
75
 
76
  [tool.ruff.lint]
77
- select = ["ALL"]
78
- ignore = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  fixable = ["ALL"]
80
  pydocstyle.convention = "google"
 
26
  "Programming Language :: Python :: 3 :: Only",
27
  ]
28
  dependencies=[
29
+ "torch",
30
  "ase",
31
  "torch_dftd>=0.4.0",
32
  "huggingface_hub",
33
+ "torch-geometric",
34
  "safetensors"
35
  ]
36
 
37
+ [project.optional-dependencies]
38
+ m3gnet = ["matgl", "dgl", "torch<=2.2.1"]
39
+ mace = ["mace-torch"]
40
+ chgnet = ["chgnet"]
41
+ fairchem = ["fairchem"]
42
+
43
  [project.urls]
44
  Homepage = "https://github.com/atomind-ai/mlip-arena"
45
  Issues = "https://github.com/atomind-ai/mlip-arena/issues"
 
80
  indent-width = 4
81
 
82
  [tool.ruff.lint]
83
+ select = [
84
+ "B", # flake8-bugbear
85
+ "C4", # flake8-comprehensions
86
+ "E", # pycodestyle error
87
+ "EXE", # flake8-executable
88
+ "F", # pyflakes
89
+ "FA", # flake8-future-annotations
90
+ "FBT003", # boolean-positional-value-in-call
91
+ "FLY", # flynt
92
+ "I", # isort
93
+ "ICN", # flake8-import-conventions
94
+ "PD", # pandas-vet
95
+ "PERF", # perflint
96
+ "PIE", # flake8-pie
97
+ "PL", # pylint
98
+ "PT", # flake8-pytest-style
99
+ "PYI", # flakes8-pyi
100
+ "Q", # flake8-quotes
101
+ "RET", # flake8-return
102
+ "RSE", # flake8-raise
103
+ "RUF", # Ruff-specific rules
104
+ "SIM", # flake8-simplify
105
+ "SLOT", # flake8-slots
106
+ "TCH", # flake8-type-checking
107
+ "TID", # tidy imports
108
+ "TID", # flake8-tidy-imports
109
+ "UP", # pyupgrade
110
+ "W", # pycodestyle warning
111
+ "YTT", # flake8-2020
112
+ ]
113
+ ignore = [
114
+ "C408", # Unnecessary dict call
115
+ "PLR", # Design related pylint codes
116
+ "E501", # Line too long
117
+ "B028", # No explicit stacklevel
118
+ "EM101", # Exception must not use a string literal
119
+ "EM102", # Exception must not use an f-string literal
120
+ "G004", # f-string in Logging statement
121
+ "RUF015", # Prefer next(iter())
122
+ "RET505", # Unnecessary `elif` after `return`
123
+ "PT004", # Fixture does not return anthing
124
+ "B017", # pytest.raises
125
+ "PT011", # pytest.raises
126
+ "PT012", # pytest.raises"
127
+ "E741", # ambigous variable naming, i.e. one letter
128
+ "FBT003", # boolean positional variable in function call
129
+ "PERF203", # `try`-`except` within a loop incurs performance overhead (no overhead in Py 3.11+)
130
+ ]
131
  fixable = ["ALL"]
132
  pydocstyle.convention = "google"
serve/app.py CHANGED
@@ -4,8 +4,11 @@ st.set_page_config(
4
  layout="wide",
5
  page_title="MLIP Arena",
6
  page_icon=":shark:",
7
- # initial_sidebar_state="expanded",
8
- menu_items=None
 
 
 
9
  )
10
 
11
  # if "logged_in" not in st.session_state:
@@ -24,19 +27,19 @@ st.set_page_config(
24
  # login_page = st.Page(login, title="Log in", icon=":material/login:")
25
  # logout_page = st.Page(logout, title="Log out", icon=":material/logout:")
26
 
27
- dashboard = st.Page(
28
- "reports/dashboard.py", title="Dashboard", icon=":material/dashboard:"
29
  )
30
- bugs = st.Page("reports/bugs.py", title="Bug reports", icon=":material/bug_report:")
31
  alerts = st.Page(
32
- "reports/alerts.py", title="System alerts", icon=":material/notification_important:"
33
  )
34
 
35
  search = st.Page("tools/search.py", title="Search", icon=":material/search:")
36
  history = st.Page("tools/history.py", title="History", icon=":material/history:")
37
  ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
38
 
39
- diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon="", default=True)
40
 
41
  # if st.session_state.logged_in:
42
  pg = st.navigation(
@@ -44,6 +47,7 @@ pg = st.navigation(
44
  # "Account": [logout_page],
45
  # "Reports": [dashboard, bugs, alerts],
46
  # "Tools": [search, history, ptable],
 
47
  "Tasks": [diatomics],
48
  "Tools": [ptable],
49
  }
 
4
  layout="wide",
5
  page_title="MLIP Arena",
6
  page_icon=":shark:",
7
+ initial_sidebar_state="expanded",
8
+ menu_items={
9
+ "About": 'https://github.com/atomind-ai/mlip-arena',
10
+ "Report a bug": "https://github.com/atomind-ai/mlip-arena/issues/new",
11
+ }
12
  )
13
 
14
  # if "logged_in" not in st.session_state:
 
27
  # login_page = st.Page(login, title="Log in", icon=":material/login:")
28
  # logout_page = st.Page(logout, title="Log out", icon=":material/logout:")
29
 
30
+ leaderboard = st.Page(
31
+ "models/leaderboard.py", title="Leaderboard", icon=":material/trophy:"
32
  )
33
+ bugs = st.Page("models/bugs.py", title="Bug reports", icon=":material/bug_report:")
34
  alerts = st.Page(
35
+ "models/alerts.py", title="System alerts", icon=":material/notification_important:"
36
  )
37
 
38
  search = st.Page("tools/search.py", title="Search", icon=":material/search:")
39
  history = st.Page("tools/history.py", title="History", icon=":material/history:")
40
  ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
41
 
42
+ diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
43
 
44
  # if st.session_state.logged_in:
45
  pg = st.navigation(
 
47
  # "Account": [logout_page],
48
  # "Reports": [dashboard, bugs, alerts],
49
  # "Tools": [search, history, ptable],
50
+ "Models": [leaderboard],
51
  "Tasks": [diatomics],
52
  "Tools": [ptable],
53
  }
serve/reports/alerts.py DELETED
@@ -1,4 +0,0 @@
1
- import streamlit as st
2
-
3
-
4
- st.markdown("# Alerts")
 
 
 
 
 
serve/reports/bugs.py DELETED
@@ -1,4 +0,0 @@
1
- import streamlit as st
2
-
3
-
4
- st.markdown("# Bugs")
 
 
 
 
 
serve/reports/dashboard.py DELETED
@@ -1,23 +0,0 @@
1
- import numpy as np
2
- import plotly.figure_factory as ff
3
- import streamlit as st
4
-
5
- st.markdown("# Dashboard")
6
-
7
- # Add histogram data
8
- x1 = np.random.randn(200) - 2
9
- x2 = np.random.randn(200)
10
- x3 = np.random.randn(200) + 2
11
-
12
- # Group data together
13
- hist_data = [x1, x2, x3]
14
-
15
- group_labels = ["Group 1", "Group 2", "Group 3"]
16
-
17
- # Create distplot with custom bin_size
18
- fig = ff.create_distplot(
19
- hist_data, group_labels, bin_size=[.1, .25, .5]
20
- )
21
-
22
- # Plot!
23
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
serve/tasks/homonuclear-diatomics.py CHANGED
@@ -9,35 +9,46 @@ from ase.data import chemical_symbols
9
  from plotly.subplots import make_subplots
10
  from scipy.interpolate import CubicSpline
11
 
12
- color_sequence = pcolors.qualitative.Plotly
13
-
14
-
15
-
16
  st.markdown("# Homonuclear diatomics")
17
 
18
- # button to toggle plots
19
  container = st.container(border=True)
20
- energy_plot = container.checkbox("Show energy curves", value=True)
21
- force_plot = container.checkbox("Show force curves", value=True)
22
 
23
- ncols = 2
 
 
 
 
24
 
25
- DATA_DIR = Path("mlip_arena/tasks/diatomics")
26
- mlips = ["MACE-MP", "CHGNet"]
 
 
27
 
28
- dfs = [pd.read_json(DATA_DIR / mlip.lower() / "homonuclear-diatomics.json") for mlip in mlips]
29
- df = pd.concat(dfs, ignore_index=True)
30
 
 
 
 
 
31
 
 
32
 
 
 
 
33
  df.drop_duplicates(inplace=True, subset=["name", "method"])
34
 
 
 
35
  for i, symbol in enumerate(chemical_symbols[1:]):
36
 
37
  if i % ncols == 0:
38
  cols = st.columns(ncols)
39
 
40
-
41
  rows = df[df["name"] == symbol + symbol]
42
 
43
  if rows.empty:
@@ -61,23 +72,34 @@ for i, symbol in enumerate(chemical_symbols[1:]):
61
 
62
  rs = rs[ind]
63
  es = es[ind]
64
- es = es - es[-1]
65
- fs = fs[ind]
 
 
66
 
67
- xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
 
 
 
 
 
 
68
 
69
  if energy_plot:
70
- cs = CubicSpline(rs, es)
71
- ys = cs(xs)
 
 
 
72
 
73
- elo = min(elo, ys.min()*1.2, -1)
74
 
75
  fig.add_trace(
76
  go.Scatter(
77
  x=xs, y=ys,
78
  mode="lines",
79
  line=dict(
80
- color=color_sequence[j % len(color_sequence)],
81
  width=2,
82
  ),
83
  name=method,
@@ -85,33 +107,32 @@ for i, symbol in enumerate(chemical_symbols[1:]):
85
  secondary_y=False,
86
  )
87
 
88
- if force_plot:
89
- cs = CubicSpline(rs, fs)
90
- ys = cs(xs)
91
 
92
- flo = min(flo, ys.min()*1.2)
93
 
94
  fig.add_trace(
95
  go.Scatter(
96
  x=xs, y=ys,
97
  mode="lines",
98
  line=dict(
99
- color=color_sequence[j % len(color_sequence)],
100
  width=1,
101
  dash="dot",
102
  ),
103
  name=method,
104
- showlegend=False if energy_plot else True,
105
  ),
106
  secondary_y=True,
107
  )
108
 
 
109
 
110
  fig.update_layout(
111
  showlegend=True,
112
- title_text=f"{symbol}-{symbol}",
113
  title_x=0.5,
114
- # yaxis_range=[ylo, 2*(abs(ylo))],
115
  )
116
 
117
  # Set x-axis title
@@ -128,21 +149,16 @@ for i, symbol in enumerate(chemical_symbols[1:]):
128
  )
129
  )
130
 
131
- # fig.update_yaxes(title_text="Energy [eV]", secondary_y=False)
132
-
133
  if force_plot:
134
 
135
  fig.update_layout(
136
  yaxis2=dict(
137
  title=dict(text="Force [eV/Å]"),
138
  side="right",
139
- range=[flo, 2*(abs(flo))],
140
  overlaying="y",
141
  tickmode="sync",
142
  ),
143
  )
144
 
145
- # fig.update_yaxes(title_text="Force [eV/Å]", secondary_y=True)
146
-
147
- # cols[i % ncols].title(f"{row['name']}")
148
  cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)
 
9
  from plotly.subplots import make_subplots
10
  from scipy.interpolate import CubicSpline
11
 
 
 
 
 
12
  st.markdown("# Homonuclear diatomics")
13
 
14
+ st.markdown("### Methods")
15
  container = st.container(border=True)
16
+ methods = container.multiselect("MLIPs", ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF", "eSCN"], ["MACE-MP", "Equiformer", "CHGNet", "eSCN"])
17
+ methods += container.multiselect("DFT Methods", ["GPAW"], [])
18
 
19
+ st.markdown("### Settings")
20
+ vis = st.container(border=True)
21
+ energy_plot = vis.checkbox("Show energy curves", value=True)
22
+ force_plot = vis.checkbox("Show force curves", value=True)
23
+ ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=3)
24
 
25
+ # Get all attributes from pcolors.qualitative
26
+ all_attributes = dir(pcolors.qualitative)
27
+ color_palettes = {attr: getattr(pcolors.qualitative, attr) for attr in all_attributes if isinstance(getattr(pcolors.qualitative, attr), list)}
28
+ color_palettes.pop("__all__", None)
29
 
30
+ palette_names = list(color_palettes.keys())
31
+ palette_colors = list(color_palettes.values())
32
 
33
+ palette_name = vis.selectbox(
34
+ "Color sequence",
35
+ options=palette_names, index=22
36
+ )
37
 
38
+ color_sequence = color_palettes[palette_name] # type: ignore
39
 
40
+ DATA_DIR = Path("mlip_arena/tasks/diatomics")
41
+ dfs = [pd.read_json(DATA_DIR / method.lower() / "homonuclear-diatomics.json") for method in methods]
42
+ df = pd.concat(dfs, ignore_index=True)
43
  df.drop_duplicates(inplace=True, subset=["name", "method"])
44
 
45
+ method_color_mapping = {method: color_sequence[i % len(color_sequence)] for i, method in enumerate(df["method"].unique())}
46
+
47
  for i, symbol in enumerate(chemical_symbols[1:]):
48
 
49
  if i % ncols == 0:
50
  cols = st.columns(ncols)
51
 
 
52
  rows = df[df["name"] == symbol + symbol]
53
 
54
  if rows.empty:
 
72
 
73
  rs = rs[ind]
74
  es = es[ind]
75
+ if "GPAW" not in method:
76
+ es = es - es[-1]
77
+ else:
78
+ pass
79
 
80
+ if "GPAW" not in method:
81
+ fs = fs[ind]
82
+
83
+ if "GPAW" in method:
84
+ xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
85
+ else:
86
+ xs = rs
87
 
88
  if energy_plot:
89
+ if "GPAW" in method:
90
+ cs = CubicSpline(rs, es)
91
+ ys = cs(xs)
92
+ else:
93
+ ys = es
94
 
95
+ elo = min(elo, max(ys.min()*1.2, -15), -1)
96
 
97
  fig.add_trace(
98
  go.Scatter(
99
  x=xs, y=ys,
100
  mode="lines",
101
  line=dict(
102
+ color=method_color_mapping[method],
103
  width=2,
104
  ),
105
  name=method,
 
107
  secondary_y=False,
108
  )
109
 
110
+ if force_plot and "GPAW" not in method:
111
+ ys = fs
 
112
 
113
+ flo = min(flo, max(ys.min()*1.2, -50))
114
 
115
  fig.add_trace(
116
  go.Scatter(
117
  x=xs, y=ys,
118
  mode="lines",
119
  line=dict(
120
+ color=method_color_mapping[method],
121
  width=1,
122
  dash="dot",
123
  ),
124
  name=method,
125
+ showlegend=not energy_plot,
126
  ),
127
  secondary_y=True,
128
  )
129
 
130
+ name = f"{symbol}-{symbol}"
131
 
132
  fig.update_layout(
133
  showlegend=True,
134
+ title_text=f"{name}",
135
  title_x=0.5,
 
136
  )
137
 
138
  # Set x-axis title
 
149
  )
150
  )
151
 
 
 
152
  if force_plot:
153
 
154
  fig.update_layout(
155
  yaxis2=dict(
156
  title=dict(text="Force [eV/Å]"),
157
  side="right",
158
+ range=[flo, 1.5*abs(flo)],
159
  overlaying="y",
160
  tickmode="sync",
161
  ),
162
  )
163
 
 
 
 
164
  cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)
serve/tools/ptable.py CHANGED
@@ -76,7 +76,8 @@ def display_element_details():
76
  if st.button("Back to Periodic Table"):
77
  st.session_state.selected_element = None
78
  st.session_state.selected_name = None
79
- st.experimental_rerun()
 
80
 
81
 
82
  st.title("Periodic Table")
 
76
  if st.button("Back to Periodic Table"):
77
  st.session_state.selected_element = None
78
  st.session_state.selected_name = None
79
+ st.rerun()
80
+ # st.experimental_rerun()
81
 
82
 
83
  st.title("Periodic Table")