cyrusyc commited on
Commit
3eda6d3
1 Parent(s): 7cbf186

update ptable, add statbility, update pyproject.toml

Browse files
mlip_arena/models/registry.yaml CHANGED
@@ -3,6 +3,7 @@
3
  MACE-MP(M):
4
  module: externals
5
  class: MACE_MP_Medium
 
6
  username: cyrusyc # HF username
7
  last-update: 2024-03-25T14:30:00
8
  datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
@@ -20,6 +21,7 @@ MACE-MP(M):
20
  CHGNet:
21
  module: externals
22
  class: CHGNet
 
23
  username: cyrusyc
24
  last-update: 2024-07-08T00:00:00
25
  datetime: 2024-07-08T00:00:00
@@ -31,6 +33,7 @@ CHGNet:
31
  EquiformerV2(OC22):
32
  module: externals
33
  class: EquiformerV2
 
34
  username: cyrusyc
35
  last-update: 2024-07-08T00:00:00
36
  datetime: 2024-07-08T00:00:00
@@ -42,6 +45,7 @@ EquiformerV2(OC22):
42
  eSCN(OC20):
43
  module: externals
44
  class: eSCN
 
45
  username: cyrusyc
46
  last-update: 2024-07-08T00:00:00
47
  datetime: 2024-07-08T00:00:00
 
3
  MACE-MP(M):
4
  module: externals
5
  class: MACE_MP_Medium
6
+ family: mace
7
  username: cyrusyc # HF username
8
  last-update: 2024-03-25T14:30:00
9
  datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
 
21
  CHGNet:
22
  module: externals
23
  class: CHGNet
24
+ family: chgnet
25
  username: cyrusyc
26
  last-update: 2024-07-08T00:00:00
27
  datetime: 2024-07-08T00:00:00
 
33
  EquiformerV2(OC22):
34
  module: externals
35
  class: EquiformerV2
36
+ family: equiformer
37
  username: cyrusyc
38
  last-update: 2024-07-08T00:00:00
39
  datetime: 2024-07-08T00:00:00
 
45
  eSCN(OC20):
46
  module: externals
47
  class: eSCN
48
+ family: escn
49
  username: cyrusyc
50
  last-update: 2024-07-08T00:00:00
51
  datetime: 2024-07-08T00:00:00
mlip_arena/tasks/stability/run.py CHANGED
@@ -119,7 +119,7 @@ def _get_ensemble_defaults(
119
  ase_md_kwargs.pop("externalstress", None)
120
  elif ensemble == "npt":
121
  ase_md_kwargs["temperature_K"] = t_schedule[0]
122
- ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
123
 
124
  if isinstance(dynamics, str) and dynamics.lower() == "langevin":
125
  ase_md_kwargs["friction"] = ase_md_kwargs.get(
 
119
  ase_md_kwargs.pop("externalstress", None)
120
  elif ensemble == "npt":
121
  ase_md_kwargs["temperature_K"] = t_schedule[0]
122
+ ase_md_kwargs["externalstress"] = p_schedule[0] # * 1e3 * units.bar
123
 
124
  if isinstance(dynamics, str) and dynamics.lower() == "langevin":
125
  ase_md_kwargs["friction"] = ase_md_kwargs.get(
pyproject.toml CHANGED
@@ -26,12 +26,13 @@ classifiers=[
26
  "Programming Language :: Python :: 3 :: Only",
27
  ]
28
  dependencies=[
29
- "torch",
30
  "ase",
 
31
  "torch_dftd>=0.4.0",
32
  "huggingface_hub",
33
  "torch-geometric",
34
- "safetensors"
 
35
  ]
36
 
37
  [project.optional-dependencies]
@@ -39,6 +40,11 @@ m3gnet = ["matgl", "dgl", "torch<=2.2.1"]
39
  mace = ["mace-torch"]
40
  chgnet = ["chgnet"]
41
  fairchem = ["fairchem"]
 
 
 
 
 
42
 
43
  [project.urls]
44
  Homepage = "https://github.com/atomind-ai/mlip-arena"
 
26
  "Programming Language :: Python :: 3 :: Only",
27
  ]
28
  dependencies=[
 
29
  "ase",
30
+ "torch",
31
  "torch_dftd>=0.4.0",
32
  "huggingface_hub",
33
  "torch-geometric",
34
+ "safetensors",
35
+ "pymatgen"
36
  ]
37
 
38
  [project.optional-dependencies]
 
40
  mace = ["mace-torch"]
41
  chgnet = ["chgnet"]
42
  fairchem = ["fairchem"]
43
+ app = [
44
+ "streamlit",
45
+ "plotly",
46
+ "bokeh==2.4.3",
47
+ ]
48
 
49
  [project.urls]
50
  Homepage = "https://github.com/atomind-ai/mlip-arena"
serve/app.py CHANGED
@@ -40,6 +40,8 @@ history = st.Page("tools/history.py", title="History", icon=":material/history:"
40
  ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
41
 
42
  diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
 
 
43
 
44
  # if st.session_state.logged_in:
45
  pg = st.navigation(
@@ -48,7 +50,7 @@ pg = st.navigation(
48
  # "Reports": [dashboard, bugs, alerts],
49
  # "Tools": [search, history, ptable],
50
  "": [leaderboard],
51
- "Tasks": [diatomics],
52
  "Tools": [ptable],
53
  }
54
  )
 
40
  ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")
41
 
42
  diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)
43
+ stability = st.Page("tasks/stability.py", title="Stability", icon=":material/target:")
44
+
45
 
46
  # if st.session_state.logged_in:
47
  pg = st.navigation(
 
50
  # "Reports": [dashboard, bugs, alerts],
51
  # "Tools": [search, history, ptable],
52
  "": [leaderboard],
53
+ "Tasks": [diatomics, stability],
54
  "Tools": [ptable],
55
  }
56
  )
serve/tasks/homonuclear-diatomics.py CHANGED
@@ -9,6 +9,8 @@ from ase.data import chemical_symbols
9
  from plotly.subplots import make_subplots
10
  from scipy.interpolate import CubicSpline
11
 
 
 
12
  st.markdown("# Homonuclear diatomics")
13
 
14
  st.markdown("### Methods")
 
9
  from plotly.subplots import make_subplots
10
  from scipy.interpolate import CubicSpline
11
 
12
+ from mlip_arena.models.utils import MLIPMap
13
+
14
  st.markdown("# Homonuclear diatomics")
15
 
16
  st.markdown("### Methods")
serve/tasks/stability.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import plotly.colors as pcolors
6
+ import plotly.graph_objects as go
7
+ import streamlit as st
8
+ from ase.data import chemical_symbols
9
+ from ase.io import read, write
10
+ from plotly.subplots import make_subplots
11
+ from scipy.interpolate import CubicSpline
12
+
13
+ from mlip_arena.models.utils import MLIPMap
14
+
15
+ st.markdown("# Stability")
16
+
17
+ st.markdown("### Methods")
18
+ container = st.container(border=True)
19
+ methods = container.multiselect("MLIPs", ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF", "eSCN", "ALIGNN"], ["MACE-MP", "Equiformer", "CHGNet", "eSCN", "ALIGNN"])
20
+
21
+
22
+ DATA_DIR = Path("mlip_arena/tasks/stability")
23
+
24
+
25
+
26
+
27
+
28
+
serve/tools/ptable.py CHANGED
@@ -1,93 +1,78 @@
1
- import streamlit as st
2
- from ase.data import chemical_symbols
3
- from pymatgen.core import Element
4
-
5
- elements = [Element.from_Z(z) for z in range(1, 119)]
6
-
7
- # Define the number of rows and columns in the periodic table
8
- rows = 9 # There are 7 rows in the conventional periodic table
9
- columns = 18
10
-
11
- # Define a function to display the periodic table
12
- def display_periodic_table():
13
- # elements = [
14
- # (element, element) for element in chemical_symbols[1:]
15
- # ]
16
-
17
- # cols = st.columns(18, gap='small', vertical_alignment='bottom') # Create 18 columns for the periodic table layout
18
-
19
- row = 0
20
- for element in elements:
21
- symbol = element.symbol
22
- atomic_number = element.Z
23
- group = element.group
24
-
25
- if element.row > row:
26
- cols = st.columns(columns, gap='small', vertical_alignment='bottom')
27
- row = element.row
28
-
29
- if element.block == 'f':
30
- continue
31
-
32
- with cols[group - 1]:
33
- if st.button(symbol, use_container_width=True):
34
- st.session_state.selected_element = symbol
35
- st.session_state.selected_name = symbol
36
- st.rerun()
37
- # st.experimental_rerun()
38
-
39
- for element in elements:
40
- symbol = element.symbol
41
- atomic_number = element.Z
42
- group = element.group
43
-
44
- if element.row > row:
45
- cols = st.columns(columns, gap='small', vertical_alignment='bottom')
46
- row = element.row
47
-
48
- if element.block == 'f':
49
- noble = Element.from_row_and_group(row-1, 18)
50
- row += 2
51
- group += atomic_number - noble.Z - 2
52
- else:
53
- continue
54
-
55
- with cols[group - 1]:
56
- if st.button(symbol, use_container_width=True):
57
- st.session_state.selected_element = symbol
58
- st.session_state.selected_name = symbol
59
- st.rerun()
60
- # st.experimental_rerun()
61
-
62
-
63
- # for idx, (symbol, name) in enumerate(elements):
64
- # with cols[idx % 18]: # Place each element in the correct column
65
- # if st.button(symbol, use_container_width=True):
66
- # st.session_state.selected_element = symbol
67
- # st.session_state.selected_name = name
68
- # st.experimental_rerun()
69
-
70
- # Define a function to display the details of an element
71
- def display_element_details():
72
- symbol = st.session_state.selected_element
73
- name = st.session_state.selected_name
74
- st.write(f"### {name} ({symbol})")
75
- st.write(f"Details about {name} ({symbol}) will be displayed here.")
76
- if st.button("Back to Periodic Table"):
77
- st.session_state.selected_element = None
78
- st.session_state.selected_name = None
79
- st.rerun()
80
- # st.experimental_rerun()
81
-
82
-
83
- st.title("Periodic Table")
84
-
85
- # st.balloons()
86
- if 'selected_element' not in st.session_state:
87
- st.session_state.selected_element = None
88
-
89
- if st.session_state.selected_element:
90
- display_element_details()
91
- else:
92
- display_periodic_table()
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+
3
+ # NOTE: https://stackoverflow.com/questions/77062368/streamlit-bokeh-event-callback-to-get-clicked-values
4
+ # Taptool: https://docs.bokeh.org/en/2.4.2/docs/reference/models/tools.html#taptool
5
+
6
+ import streamlit as st
7
+ from bokeh.plotting import figure
8
+ from bokeh.plotting import figure, show
9
+ from bokeh.sampledata.periodic_table import elements
10
+ from bokeh.transform import dodge, factor_cmap
11
+
12
+ periods = ["I", "II", "III", "IV", "V", "VI", "VII"]
13
+ groups = [str(x) for x in range(1, 19)]
14
+
15
+ df = elements.copy()
16
+ df["atomic mass"] = df["atomic mass"].astype(str)
17
+ df["group"] = df["group"].astype(str)
18
+ df["period"] = [periods[x-1] for x in df.period]
19
+ df = df[df.group != "-"]
20
+ df = df[df.symbol != "Lr"]
21
+ df = df[df.symbol != "Lu"]
22
+
23
+ cmap = {
24
+ "alkali metal" : "#a6cee3",
25
+ "alkaline earth metal" : "#1f78b4",
26
+ "metal" : "#d93b43",
27
+ "halogen" : "#999d9a",
28
+ "metalloid" : "#e08d49",
29
+ "noble gas" : "#eaeaea",
30
+ "nonmetal" : "#f1d4Af",
31
+ "transition metal" : "#599d7A",
32
+ }
33
+
34
+ TOOLTIPS = [
35
+ ("Name", "@name"),
36
+ ("Atomic number", "@{atomic number}"),
37
+ ("Atomic mass", "@{atomic mass}"),
38
+ ("Type", "@metal"),
39
+ ("CPK color", "$color[hex, swatch]:CPK"),
40
+ ("Electronic configuration", "@{electronic configuration}"),
41
+ ]
42
+
43
+ p = figure(title="Periodic Table (omitting LA and AC Series)", width=1000, height=450,
44
+ x_range=groups, y_range=list(reversed(periods)),
45
+ tools="hover", toolbar_location=None, tooltips=TOOLTIPS)
46
+
47
+ r = p.rect("group", "period", 0.95, 0.95, source=df, fill_alpha=0.6, legend_field="metal",
48
+ color=factor_cmap('metal', palette=list(cmap.values()), factors=list(cmap.keys())))
49
+
50
+ text_props = dict(source=df, text_align="left", text_baseline="middle")
51
+
52
+ x = dodge("group", -0.4, range=p.x_range)
53
+
54
+ p.text(x=x, y="period", text="symbol", text_font_style="bold", **text_props)
55
+
56
+ p.text(x=x, y=dodge("period", 0.3, range=p.y_range), text="atomic number",
57
+ text_font_size="11px", **text_props)
58
+
59
+ p.text(x=x, y=dodge("period", -0.35, range=p.y_range), text="name",
60
+ text_font_size="7px", **text_props)
61
+
62
+ p.text(x=x, y=dodge("period", -0.2, range=p.y_range), text="atomic mass",
63
+ text_font_size="7px", **text_props)
64
+
65
+ p.text(x=["3", "3"], y=["VI", "VII"], text=["LA", "AC"], text_align="center", text_baseline="middle")
66
+
67
+ p.outline_line_color = None
68
+ p.grid.grid_line_color = None
69
+ p.axis.axis_line_color = None
70
+ p.axis.major_tick_line_color = None
71
+ p.axis.major_label_standoff = 0
72
+ p.legend.orientation = "horizontal"
73
+ p.legend.location ="top_center"
74
+ p.hover.renderers = [r] # only hover element boxes
75
+
76
+ st.bokeh_chart(p, use_container_width=True)
77
+
78
+ # show(p)