Spaces:
Running
Running
Yuan (Cyrus) Chiang
commited on
Commit
•
52c1bfb
1
Parent(s):
bd8cd88
Add `eqV2_86M_omat_mp_salex` model (#14)
Browse files* refactor external calculators
* add `eqV2_86M_omat_mp_salex`
* change eos test MLIPEnum
* fix bugs for MLIPEnum refactoring
* login huggingface for download ckpt during test
* remove add git credential
* use hf read only token
* rollback `fairchem==1.1.0`
* gracefully xfail for some farichem models
* reverse test installation order
* try adding opt-einsum constraint
* try installing from freeze
* try installing from freeze
* loosen cachepath
* add test pip cache
* bump mace to 0.3.6
* enforce `e3nn==0.4.4`
* add prefect test harness
* change prefect text fixture
* add diatomic curves
* add matgl json
* increase concurrent number
* pytest only on py3.11
* fix system info tuple bug
* use context manager
- .github/workflows/test.yaml +15 -3
- mlip_arena/models/__init__.py +13 -0
- mlip_arena/models/externals/__init__.py +0 -0
- mlip_arena/models/externals/alignn.py +15 -0
- mlip_arena/models/externals/chgnet.py +39 -0
- mlip_arena/models/externals/equiformer.py +55 -0
- mlip_arena/models/externals/escn.py +35 -0
- mlip_arena/models/externals/fairchem.py +124 -0
- mlip_arena/models/externals/mace-mp.py +39 -0
- mlip_arena/models/externals/mace-off.py +39 -0
- mlip_arena/models/externals/matgl.py +18 -0
- mlip_arena/models/externals/orb.py +40 -0
- mlip_arena/models/externals/sevennet.py +16 -0
- mlip_arena/models/registry.yaml +37 -15
- mlip_arena/models/utils.py +1 -14
- mlip_arena/tasks/combustion/{m3gnet → matgl}/hydrogen.json +0 -0
- mlip_arena/tasks/diatomics/fairchem/homonuclear-diatomics.json +3 -0
- mlip_arena/tasks/diatomics/matgl/homonuclear-diatomics.json +3 -0
- mlip_arena/tasks/diatomics/run.ipynb +0 -0
- mlip_arena/tasks/eos/run.py +1 -1
- mlip_arena/tasks/md.py +2 -1
- mlip_arena/tasks/optimize.py +2 -1
- mlip_arena/tasks/run.py +0 -318
- pyproject.toml +4 -3
- serve/leaderboard.py +2 -1
- serve/ranks/homonuclear-diatomics.py +2 -0
- serve/tasks/homonuclear-diatomics.py +3 -2
- tests/test_eos.py +25 -18
- tests/test_external_calculators.py +1 -1
.github/workflows/test.yaml
CHANGED
@@ -7,17 +7,20 @@ jobs:
|
|
7 |
runs-on: ubuntu-latest
|
8 |
|
9 |
strategy:
|
|
|
10 |
matrix:
|
11 |
python-version: ["3.10", "3.11"]
|
|
|
12 |
|
13 |
steps:
|
14 |
- name: Checkout repository
|
15 |
-
uses: actions/checkout@
|
16 |
|
17 |
- name: Set up Python ${{ matrix.python-version }}
|
18 |
-
uses: actions/setup-python@
|
19 |
with:
|
20 |
python-version: ${{ matrix.python-version }}
|
|
|
21 |
|
22 |
- name: Install dependencies
|
23 |
run: |
|
@@ -25,10 +28,19 @@ jobs:
|
|
25 |
pip install torch==2.2.0
|
26 |
bash scripts/install-pyg.sh
|
27 |
bash scripts/install-dgl.sh
|
28 |
-
pip install .[test]
|
29 |
pip install .[mace]
|
|
|
30 |
pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann#egg=af434039ae14bedcbb838a7808924d6689274168"
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
- name: Run tests
|
33 |
env:
|
34 |
PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
|
|
|
7 |
runs-on: ubuntu-latest
|
8 |
|
9 |
strategy:
|
10 |
+
# max-parallel: 2
|
11 |
matrix:
|
12 |
python-version: ["3.10", "3.11"]
|
13 |
+
|
14 |
|
15 |
steps:
|
16 |
- name: Checkout repository
|
17 |
+
uses: actions/checkout@v4
|
18 |
|
19 |
- name: Set up Python ${{ matrix.python-version }}
|
20 |
+
uses: actions/setup-python@v5
|
21 |
with:
|
22 |
python-version: ${{ matrix.python-version }}
|
23 |
+
cache: 'pip'
|
24 |
|
25 |
- name: Install dependencies
|
26 |
run: |
|
|
|
28 |
pip install torch==2.2.0
|
29 |
bash scripts/install-pyg.sh
|
30 |
bash scripts/install-dgl.sh
|
|
|
31 |
pip install .[mace]
|
32 |
+
pip install .[test]
|
33 |
pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann#egg=af434039ae14bedcbb838a7808924d6689274168"
|
34 |
|
35 |
+
- name: List dependencies
|
36 |
+
run: pip list
|
37 |
+
|
38 |
+
- name: Login huggingface
|
39 |
+
env:
|
40 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN_READ_ONLY }}
|
41 |
+
run:
|
42 |
+
huggingface-cli login --token $HF_TOKEN
|
43 |
+
|
44 |
- name: Run tests
|
45 |
env:
|
46 |
PREFECT_API_KEY: ${{ secrets.PREFECT_API_KEY }}
|
mlip_arena/models/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
import torch
|
@@ -14,6 +16,17 @@ from torch import nn
|
|
14 |
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
|
15 |
REGISTRY = yaml.safe_load(f)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
class MLIP(
|
19 |
nn.Module,
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import importlib
|
4 |
+
from enum import Enum
|
5 |
from pathlib import Path
|
6 |
|
7 |
import torch
|
|
|
16 |
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
|
17 |
REGISTRY = yaml.safe_load(f)
|
18 |
|
19 |
+
MLIPMap = {}
|
20 |
+
|
21 |
+
for model, metadata in REGISTRY.items():
|
22 |
+
try:
|
23 |
+
module = importlib.import_module(f"{__package__}.{metadata['module']}.{metadata['family']}")
|
24 |
+
MLIPMap[model] = getattr(module, metadata["class"])
|
25 |
+
except ModuleNotFoundError as e:
|
26 |
+
print(e)
|
27 |
+
continue
|
28 |
+
|
29 |
+
MLIPEnum = Enum("MLIPEnum", MLIPMap)
|
30 |
|
31 |
class MLIP(
|
32 |
nn.Module,
|
mlip_arena/models/externals/__init__.py
ADDED
File without changes
|
mlip_arena/models/externals/alignn.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from alignn.ff.ff import AlignnAtomwiseCalculator, default_path
|
4 |
+
|
5 |
+
from mlip_arena.models.utils import get_freer_device
|
6 |
+
|
7 |
+
|
8 |
+
class ALIGNN(AlignnAtomwiseCalculator):
|
9 |
+
def __init__(self, device=None, **kwargs) -> None:
|
10 |
+
# TODO: cannot control version
|
11 |
+
# _ = get_figshare_model_ff(dir_path=dir_path)
|
12 |
+
model_path = default_path()
|
13 |
+
|
14 |
+
device = device or get_freer_device()
|
15 |
+
super().__init__(path=model_path, device=device, **kwargs)
|
mlip_arena/models/externals/chgnet.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
from ase import Atoms
|
6 |
+
from chgnet.model.dynamics import CHGNetCalculator
|
7 |
+
from chgnet.model.model import CHGNet as CHGNetModel
|
8 |
+
|
9 |
+
from mlip_arena.models.utils import get_freer_device
|
10 |
+
|
11 |
+
|
12 |
+
class CHGNet(CHGNetCalculator):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
checkpoint: CHGNetModel | None = None, # TODO: specifiy version
|
16 |
+
device: str | None = None,
|
17 |
+
stress_weight: float | None = 1 / 160.21766208,
|
18 |
+
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
|
19 |
+
**kwargs,
|
20 |
+
) -> None:
|
21 |
+
use_device = device or str(get_freer_device())
|
22 |
+
super().__init__(
|
23 |
+
model=checkpoint,
|
24 |
+
use_device=use_device,
|
25 |
+
stress_weight=stress_weight,
|
26 |
+
on_isolated_atoms=on_isolated_atoms,
|
27 |
+
**kwargs,
|
28 |
+
)
|
29 |
+
|
30 |
+
def calculate(
|
31 |
+
self,
|
32 |
+
atoms: Atoms | None = None,
|
33 |
+
properties: list | None = None,
|
34 |
+
system_changes: list | None = None,
|
35 |
+
) -> None:
|
36 |
+
super().calculate(atoms, properties, system_changes)
|
37 |
+
|
38 |
+
# for ase.io.write compatibility
|
39 |
+
self.results.pop("crystal_fea", None)
|
mlip_arena/models/externals/equiformer.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
from ase import Atoms
|
7 |
+
from fairchem.core import OCPCalculator
|
8 |
+
|
9 |
+
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
|
10 |
+
REGISTRY = yaml.safe_load(f)
|
11 |
+
|
12 |
+
|
13 |
+
class EquiformerV2(OCPCalculator):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
|
17 |
+
# TODO: cannot assign device
|
18 |
+
local_cache="/tmp/ocp/",
|
19 |
+
cpu=False,
|
20 |
+
seed=0,
|
21 |
+
**kwargs,
|
22 |
+
) -> None:
|
23 |
+
super().__init__(
|
24 |
+
model_name=checkpoint,
|
25 |
+
local_cache=local_cache,
|
26 |
+
cpu=cpu,
|
27 |
+
seed=seed,
|
28 |
+
**kwargs,
|
29 |
+
)
|
30 |
+
|
31 |
+
def calculate(self, atoms: Atoms, properties, system_changes) -> None:
|
32 |
+
super().calculate(atoms, properties, system_changes)
|
33 |
+
|
34 |
+
self.results.update(
|
35 |
+
force=atoms.get_forces(),
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
class EquiformerV2OC20(OCPCalculator):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
|
43 |
+
# TODO: cannot assign device
|
44 |
+
local_cache="/tmp/ocp/",
|
45 |
+
cpu=False,
|
46 |
+
seed=0,
|
47 |
+
**kwargs,
|
48 |
+
) -> None:
|
49 |
+
super().__init__(
|
50 |
+
model_name=checkpoint,
|
51 |
+
local_cache=local_cache,
|
52 |
+
cpu=cpu,
|
53 |
+
seed=seed,
|
54 |
+
**kwargs,
|
55 |
+
)
|
mlip_arena/models/externals/escn.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
from ase import Atoms
|
7 |
+
from fairchem.core import OCPCalculator
|
8 |
+
|
9 |
+
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
|
10 |
+
REGISTRY = yaml.safe_load(f)
|
11 |
+
|
12 |
+
class eSCN(OCPCalculator):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
checkpoint=REGISTRY["eSCN(OC20)"]["checkpoint"], # "eSCN-L6-M3-Lay20-S2EF-OC20-All+MD"
|
16 |
+
# TODO: cannot assign device
|
17 |
+
local_cache="/tmp/ocp/",
|
18 |
+
cpu=False,
|
19 |
+
seed=0,
|
20 |
+
**kwargs,
|
21 |
+
) -> None:
|
22 |
+
super().__init__(
|
23 |
+
model_name=checkpoint,
|
24 |
+
local_cache=local_cache,
|
25 |
+
cpu=cpu,
|
26 |
+
seed=seed,
|
27 |
+
**kwargs,
|
28 |
+
)
|
29 |
+
|
30 |
+
def calculate(self, atoms: Atoms, properties, system_changes) -> None:
|
31 |
+
super().calculate(atoms, properties, system_changes)
|
32 |
+
|
33 |
+
self.results.update(
|
34 |
+
force=atoms.get_forces(),
|
35 |
+
)
|
mlip_arena/models/externals/fairchem.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
from ase import Atoms
|
7 |
+
from fairchem.core import OCPCalculator
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
|
10 |
+
with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
|
11 |
+
REGISTRY = yaml.safe_load(f)
|
12 |
+
|
13 |
+
class eqV2(OCPCalculator):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
checkpoint=REGISTRY["eqV2(OMat)"]["checkpoint"],
|
17 |
+
cache_dir=None,
|
18 |
+
cpu=False, # TODO: cannot assign device
|
19 |
+
seed=0,
|
20 |
+
**kwargs,
|
21 |
+
) -> None:
|
22 |
+
"""
|
23 |
+
Initialize an eqV2 calculator.
|
24 |
+
|
25 |
+
Parameters
|
26 |
+
----------
|
27 |
+
checkpoint : str, default="eqV2_86M_omat_mp_salex.pt"
|
28 |
+
The name of the eqV2 checkpoint to use.
|
29 |
+
local_cache : str, default="/tmp/ocp/"
|
30 |
+
The directory to store the downloaded checkpoint.
|
31 |
+
cpu : bool, default=False
|
32 |
+
Whether to run the model on CPU or GPU.
|
33 |
+
seed : int, default=0
|
34 |
+
The random seed for the model.
|
35 |
+
|
36 |
+
Other Parameters
|
37 |
+
----------------
|
38 |
+
**kwargs
|
39 |
+
Any additional keyword arguments are passed to the superclass.
|
40 |
+
"""
|
41 |
+
|
42 |
+
# https://huggingface.co/fairchem/OMAT24/resolve/main/eqV2_86M_omat_mp_salex.pt
|
43 |
+
|
44 |
+
checkpoint_path = hf_hub_download(
|
45 |
+
"fairchem/OMAT24",
|
46 |
+
filename=checkpoint,
|
47 |
+
revision="bf92f9671cb9d5b5c77ecb4aa8b317ff10b882ce",
|
48 |
+
cache_dir=cache_dir
|
49 |
+
)
|
50 |
+
super().__init__(
|
51 |
+
checkpoint_path=checkpoint_path,
|
52 |
+
cpu=cpu,
|
53 |
+
seed=seed,
|
54 |
+
**kwargs,
|
55 |
+
)
|
56 |
+
|
57 |
+
class EquiformerV2(OCPCalculator):
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
|
61 |
+
# TODO: cannot assign device
|
62 |
+
local_cache="/tmp/ocp/",
|
63 |
+
cpu=False,
|
64 |
+
seed=0,
|
65 |
+
**kwargs,
|
66 |
+
) -> None:
|
67 |
+
super().__init__(
|
68 |
+
model_name=checkpoint,
|
69 |
+
local_cache=local_cache,
|
70 |
+
cpu=cpu,
|
71 |
+
seed=seed,
|
72 |
+
**kwargs,
|
73 |
+
)
|
74 |
+
|
75 |
+
def calculate(self, atoms: Atoms, properties, system_changes) -> None:
|
76 |
+
super().calculate(atoms, properties, system_changes)
|
77 |
+
|
78 |
+
self.results.update(
|
79 |
+
force=atoms.get_forces(),
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class EquiformerV2OC20(OCPCalculator):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
|
87 |
+
# TODO: cannot assign device
|
88 |
+
local_cache="/tmp/ocp/",
|
89 |
+
cpu=False,
|
90 |
+
seed=0,
|
91 |
+
**kwargs,
|
92 |
+
) -> None:
|
93 |
+
super().__init__(
|
94 |
+
model_name=checkpoint,
|
95 |
+
local_cache=local_cache,
|
96 |
+
cpu=cpu,
|
97 |
+
seed=seed,
|
98 |
+
**kwargs,
|
99 |
+
)
|
100 |
+
|
101 |
+
class eSCN(OCPCalculator):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD", # TODO: import from registry
|
105 |
+
# TODO: cannot assign device
|
106 |
+
local_cache="/tmp/ocp/",
|
107 |
+
cpu=False,
|
108 |
+
seed=0,
|
109 |
+
**kwargs,
|
110 |
+
) -> None:
|
111 |
+
super().__init__(
|
112 |
+
model_name=checkpoint,
|
113 |
+
local_cache=local_cache,
|
114 |
+
cpu=cpu,
|
115 |
+
seed=seed,
|
116 |
+
**kwargs,
|
117 |
+
)
|
118 |
+
|
119 |
+
def calculate(self, atoms: Atoms, properties, system_changes) -> None:
|
120 |
+
super().calculate(atoms, properties, system_changes)
|
121 |
+
|
122 |
+
self.results.update(
|
123 |
+
force=atoms.get_forces(),
|
124 |
+
)
|
mlip_arena/models/externals/mace-mp.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from mace.calculators import MACECalculator
|
7 |
+
|
8 |
+
from mlip_arena.models.utils import get_freer_device
|
9 |
+
|
10 |
+
|
11 |
+
class MACE_MP_Medium(MACECalculator):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
checkpoint="http://tinyurl.com/5yyxdm76",
|
15 |
+
device: str | None = None,
|
16 |
+
default_dtype="float32",
|
17 |
+
**kwargs,
|
18 |
+
):
|
19 |
+
cache_dir = Path.home() / ".cache" / "mace"
|
20 |
+
checkpoint_url_name = "".join(
|
21 |
+
c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_"
|
22 |
+
)
|
23 |
+
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
|
24 |
+
if not os.path.isfile(cached_model_path):
|
25 |
+
import urllib
|
26 |
+
|
27 |
+
os.makedirs(cache_dir, exist_ok=True)
|
28 |
+
_, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path)
|
29 |
+
if "Content-Type: text/html" in http_msg:
|
30 |
+
raise RuntimeError(
|
31 |
+
f"Model download failed, please check the URL {checkpoint}"
|
32 |
+
)
|
33 |
+
model = cached_model_path
|
34 |
+
|
35 |
+
device = device or str(get_freer_device())
|
36 |
+
|
37 |
+
super().__init__(
|
38 |
+
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
|
39 |
+
)
|
mlip_arena/models/externals/mace-off.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from mace.calculators import MACECalculator
|
7 |
+
|
8 |
+
from mlip_arena.models.utils import get_freer_device
|
9 |
+
|
10 |
+
|
11 |
+
class MACE_OFF_Medium(MACECalculator):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
checkpoint="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
|
15 |
+
device: str | None = None,
|
16 |
+
default_dtype="float32",
|
17 |
+
**kwargs,
|
18 |
+
):
|
19 |
+
cache_dir = Path.home() / ".cache" / "mace"
|
20 |
+
checkpoint_url_name = "".join(
|
21 |
+
c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_"
|
22 |
+
)
|
23 |
+
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
|
24 |
+
if not os.path.isfile(cached_model_path):
|
25 |
+
import urllib
|
26 |
+
|
27 |
+
os.makedirs(cache_dir, exist_ok=True)
|
28 |
+
_, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path)
|
29 |
+
if "Content-Type: text/html" in http_msg:
|
30 |
+
raise RuntimeError(
|
31 |
+
f"Model download failed, please check the URL {checkpoint}"
|
32 |
+
)
|
33 |
+
model = cached_model_path
|
34 |
+
|
35 |
+
device = device or str(get_freer_device())
|
36 |
+
|
37 |
+
super().__init__(
|
38 |
+
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
|
39 |
+
)
|
mlip_arena/models/externals/matgl.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import matgl
|
4 |
+
import torch
|
5 |
+
from matgl.ext.ase import PESCalculator
|
6 |
+
|
7 |
+
|
8 |
+
class M3GNet(PESCalculator):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
checkpoint="M3GNet-MP-2021.2.8-PES",
|
12 |
+
# TODO: cannot assign device
|
13 |
+
state_attr: torch.Tensor | None = None,
|
14 |
+
stress_weight: float = 1.0,
|
15 |
+
**kwargs,
|
16 |
+
) -> None:
|
17 |
+
potential = matgl.load_model(checkpoint)
|
18 |
+
super().__init__(potential, state_attr, stress_weight, **kwargs)
|
mlip_arena/models/externals/orb.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import requests
|
6 |
+
from orb_models.forcefield import pretrained
|
7 |
+
from orb_models.forcefield.calculator import ORBCalculator
|
8 |
+
|
9 |
+
from mlip_arena.models.utils import get_freer_device
|
10 |
+
|
11 |
+
|
12 |
+
class ORB(ORBCalculator):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
checkpoint="orbff-v1-20240827.ckpt",
|
16 |
+
device=None,
|
17 |
+
**kwargs,
|
18 |
+
):
|
19 |
+
device = device or get_freer_device()
|
20 |
+
|
21 |
+
cache_dir = Path.home() / ".cache" / "orb"
|
22 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
23 |
+
ckpt_path = cache_dir / "orbff-v1-20240827.ckpt"
|
24 |
+
|
25 |
+
url = f"https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/{checkpoint}"
|
26 |
+
|
27 |
+
if not ckpt_path.exists():
|
28 |
+
print(f"Downloading ORB model from {url} to {ckpt_path}...")
|
29 |
+
try:
|
30 |
+
response = requests.get(url, stream=True, timeout=120)
|
31 |
+
response.raise_for_status()
|
32 |
+
with open(ckpt_path, "wb") as f:
|
33 |
+
for chunk in response.iter_content(chunk_size=8192):
|
34 |
+
f.write(chunk)
|
35 |
+
print("Download completed.")
|
36 |
+
except requests.exceptions.RequestException as e:
|
37 |
+
raise RuntimeError("Failed to download ORB model.") from e
|
38 |
+
|
39 |
+
orbff = pretrained.orb_v1(weights_path=ckpt_path, device=device)
|
40 |
+
super().__init__(orbff, device=device, **kwargs)
|
mlip_arena/models/externals/sevennet.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from sevenn.sevennet_calculator import SevenNetCalculator
|
4 |
+
|
5 |
+
from mlip_arena.models.utils import get_freer_device
|
6 |
+
|
7 |
+
|
8 |
+
class SevenNet(SevenNetCalculator):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
checkpoint="7net-0", # TODO: import from registry
|
12 |
+
device=None,
|
13 |
+
**kwargs,
|
14 |
+
):
|
15 |
+
device = device or get_freer_device()
|
16 |
+
super().__init__(checkpoint, device=device, **kwargs)
|
mlip_arena/models/registry.yaml
CHANGED
@@ -8,7 +8,7 @@ MACE-MP(M):
|
|
8 |
last-update: 2024-03-25T14:30:00
|
9 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
10 |
datasets:
|
11 |
-
-
|
12 |
cpu-tasks:
|
13 |
- alexandria
|
14 |
- qmof
|
@@ -33,7 +33,7 @@ CHGNet:
|
|
33 |
last-update: 2024-07-08T00:00:00
|
34 |
datetime: 2024-07-08T00:00:00
|
35 |
datasets:
|
36 |
-
-
|
37 |
gpu-tasks:
|
38 |
- homonuclear-diatomics
|
39 |
- stability
|
@@ -48,14 +48,14 @@ CHGNet:
|
|
48 |
M3GNet:
|
49 |
module: externals
|
50 |
class: M3GNet
|
51 |
-
family:
|
52 |
package: matgl==1.1.2
|
53 |
checkpoint:
|
54 |
username: cyrusyc
|
55 |
last-update: 2024-07-08T00:00:00
|
56 |
datetime: 2024-07-08T00:00:00
|
57 |
datasets:
|
58 |
-
-
|
59 |
gpu-tasks:
|
60 |
- homonuclear-diatomics
|
61 |
- combustion
|
@@ -76,8 +76,8 @@ ORB:
|
|
76 |
last-update: 2024-03-25T14:30:00
|
77 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
78 |
datasets:
|
79 |
-
-
|
80 |
-
-
|
81 |
cpu-tasks:
|
82 |
- alexandria
|
83 |
- qmof
|
@@ -102,7 +102,7 @@ SevenNet:
|
|
102 |
last-update: 2024-03-25T14:30:00
|
103 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
104 |
datasets:
|
105 |
-
-
|
106 |
cpu-tasks:
|
107 |
- alexandria
|
108 |
- qmof
|
@@ -117,17 +117,39 @@ SevenNet:
|
|
117 |
nvt: true
|
118 |
npt: true
|
119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
EquiformerV2(OC22):
|
121 |
module: externals
|
122 |
class: EquiformerV2
|
123 |
family: equiformer
|
124 |
-
package: fairchem-core==1.
|
125 |
checkpoint: EquiformerV2-lE4-lF100-S2EFS-OC22
|
126 |
username: cyrusyc
|
127 |
last-update: 2024-07-08T00:00:00
|
128 |
datetime: 2024-07-08T00:00:00
|
129 |
datasets:
|
130 |
-
-
|
131 |
gpu-tasks:
|
132 |
- homonuclear-diatomics
|
133 |
- combustion
|
@@ -142,13 +164,13 @@ EquiformerV2(OC20):
|
|
142 |
module: externals
|
143 |
class: EquiformerV2OC20
|
144 |
family: equiformer
|
145 |
-
package: fairchem-core==1.
|
146 |
checkpoint: EquiformerV2-31M-S2EF-OC20-All+MD
|
147 |
username: cyrusyc
|
148 |
last-update: 2024-07-08T00:00:00
|
149 |
datetime: 2024-07-08T00:00:00
|
150 |
datasets:
|
151 |
-
-
|
152 |
gpu-tasks:
|
153 |
- homonuclear-diatomics
|
154 |
- combustion
|
@@ -163,13 +185,13 @@ eSCN(OC20):
|
|
163 |
module: externals
|
164 |
class: eSCN
|
165 |
family: escn
|
166 |
-
package: fairchem-core==1.
|
167 |
checkpoint: eSCN-L6-M3-Lay20-S2EF-OC20-All+MD
|
168 |
username: cyrusyc
|
169 |
last-update: 2024-07-08T00:00:00
|
170 |
datetime: 2024-07-08T00:00:00
|
171 |
datasets:
|
172 |
-
-
|
173 |
gpu-tasks:
|
174 |
- homonuclear-diatomics
|
175 |
- combustion
|
@@ -188,7 +210,7 @@ MACE-OFF(M):
|
|
188 |
last-update: 2024-03-25T14:30:00
|
189 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
190 |
datasets:
|
191 |
-
-
|
192 |
cpu-tasks:
|
193 |
- alexandria
|
194 |
- qmof
|
@@ -211,7 +233,7 @@ ALIGNN:
|
|
211 |
last-update: 2024-07-08T00:00:00
|
212 |
datetime: 2024-07-08T00:00:00
|
213 |
datasets:
|
214 |
-
-
|
215 |
gpu-tasks:
|
216 |
- homonuclear-diatomics
|
217 |
# - combustion
|
|
|
8 |
last-update: 2024-03-25T14:30:00
|
9 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
10 |
datasets:
|
11 |
+
- MPTrj # TODO: fake HF dataset repo
|
12 |
cpu-tasks:
|
13 |
- alexandria
|
14 |
- qmof
|
|
|
33 |
last-update: 2024-07-08T00:00:00
|
34 |
datetime: 2024-07-08T00:00:00
|
35 |
datasets:
|
36 |
+
- MPTrj
|
37 |
gpu-tasks:
|
38 |
- homonuclear-diatomics
|
39 |
- stability
|
|
|
48 |
M3GNet:
|
49 |
module: externals
|
50 |
class: M3GNet
|
51 |
+
family: matgl
|
52 |
package: matgl==1.1.2
|
53 |
checkpoint:
|
54 |
username: cyrusyc
|
55 |
last-update: 2024-07-08T00:00:00
|
56 |
datetime: 2024-07-08T00:00:00
|
57 |
datasets:
|
58 |
+
- MPF
|
59 |
gpu-tasks:
|
60 |
- homonuclear-diatomics
|
61 |
- combustion
|
|
|
76 |
last-update: 2024-03-25T14:30:00
|
77 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
78 |
datasets:
|
79 |
+
- MPTrj # TODO: fake HF dataset repo
|
80 |
+
- Alexandria
|
81 |
cpu-tasks:
|
82 |
- alexandria
|
83 |
- qmof
|
|
|
102 |
last-update: 2024-03-25T14:30:00
|
103 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
104 |
datasets:
|
105 |
+
- MPTrj # TODO: fake HF dataset repo
|
106 |
cpu-tasks:
|
107 |
- alexandria
|
108 |
- qmof
|
|
|
117 |
nvt: true
|
118 |
npt: true
|
119 |
|
120 |
+
eqV2(OMat):
|
121 |
+
module: externals
|
122 |
+
class: eqV2
|
123 |
+
family: fairchem
|
124 |
+
package: fairchem-core==1.1.0
|
125 |
+
checkpoint: eqV2_86M_omat_mp_salex.pt
|
126 |
+
username: fairchem # HF handle
|
127 |
+
last-update: 2024-10-18T00:00:00
|
128 |
+
datetime: 2024-10-18T00:00:00
|
129 |
+
datasets:
|
130 |
+
- OMat
|
131 |
+
- MPTrj
|
132 |
+
- Alexandria
|
133 |
+
gpu-tasks:
|
134 |
+
- homonuclear-diatomics
|
135 |
+
prediction: EFS
|
136 |
+
nvt: true
|
137 |
+
npt: true
|
138 |
+
github: https://github.com/FAIR-Chem/fairchem
|
139 |
+
doi: https://arxiv.org/abs/2410.12771
|
140 |
+
|
141 |
+
|
142 |
EquiformerV2(OC22):
|
143 |
module: externals
|
144 |
class: EquiformerV2
|
145 |
family: equiformer
|
146 |
+
package: fairchem-core==1.1.0
|
147 |
checkpoint: EquiformerV2-lE4-lF100-S2EFS-OC22
|
148 |
username: cyrusyc
|
149 |
last-update: 2024-07-08T00:00:00
|
150 |
datetime: 2024-07-08T00:00:00
|
151 |
datasets:
|
152 |
+
- OC22
|
153 |
gpu-tasks:
|
154 |
- homonuclear-diatomics
|
155 |
- combustion
|
|
|
164 |
module: externals
|
165 |
class: EquiformerV2OC20
|
166 |
family: equiformer
|
167 |
+
package: fairchem-core==1.1.0
|
168 |
checkpoint: EquiformerV2-31M-S2EF-OC20-All+MD
|
169 |
username: cyrusyc
|
170 |
last-update: 2024-07-08T00:00:00
|
171 |
datetime: 2024-07-08T00:00:00
|
172 |
datasets:
|
173 |
+
- OC20
|
174 |
gpu-tasks:
|
175 |
- homonuclear-diatomics
|
176 |
- combustion
|
|
|
185 |
module: externals
|
186 |
class: eSCN
|
187 |
family: escn
|
188 |
+
package: fairchem-core==1.1.0
|
189 |
checkpoint: eSCN-L6-M3-Lay20-S2EF-OC20-All+MD
|
190 |
username: cyrusyc
|
191 |
last-update: 2024-07-08T00:00:00
|
192 |
datetime: 2024-07-08T00:00:00
|
193 |
datasets:
|
194 |
+
- OC20
|
195 |
gpu-tasks:
|
196 |
- homonuclear-diatomics
|
197 |
- combustion
|
|
|
210 |
last-update: 2024-03-25T14:30:00
|
211 |
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
|
212 |
datasets:
|
213 |
+
- SPICE # TODO: fake HF dataset repo
|
214 |
cpu-tasks:
|
215 |
- alexandria
|
216 |
- qmof
|
|
|
233 |
last-update: 2024-07-08T00:00:00
|
234 |
datetime: 2024-07-08T00:00:00
|
235 |
datasets:
|
236 |
+
- MP22
|
237 |
gpu-tasks:
|
238 |
- homonuclear-diatomics
|
239 |
# - combustion
|
mlip_arena/models/utils.py
CHANGED
@@ -1,20 +1,7 @@
|
|
1 |
"""Utility functions for MLIP models."""
|
2 |
|
3 |
-
import importlib
|
4 |
-
from enum import Enum
|
5 |
-
|
6 |
import torch
|
7 |
|
8 |
-
from mlip_arena.models import REGISTRY
|
9 |
-
|
10 |
-
MLIPMap = {
|
11 |
-
model: getattr(
|
12 |
-
importlib.import_module(f"{__package__}.{metadata['module']}"), metadata["class"],
|
13 |
-
)
|
14 |
-
for model, metadata in REGISTRY.items()
|
15 |
-
}
|
16 |
-
MLIPEnum = Enum("MLIPEnum", MLIPMap)
|
17 |
-
|
18 |
|
19 |
def get_freer_device() -> torch.device:
|
20 |
"""Get the GPU with the most free memory, or use MPS if available.
|
@@ -47,4 +34,4 @@ def get_freer_device() -> torch.device:
|
|
47 |
print("No GPU or MPS available. Using CPU.")
|
48 |
device = torch.device("cpu")
|
49 |
|
50 |
-
return device
|
|
|
1 |
"""Utility functions for MLIP models."""
|
2 |
|
|
|
|
|
|
|
3 |
import torch
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
def get_freer_device() -> torch.device:
|
7 |
"""Get the GPU with the most free memory, or use MPS if available.
|
|
|
34 |
print("No GPU or MPS available. Using CPU.")
|
35 |
device = torch.device("cpu")
|
36 |
|
37 |
+
return device
|
mlip_arena/tasks/combustion/{m3gnet → matgl}/hydrogen.json
RENAMED
File without changes
|
mlip_arena/tasks/diatomics/fairchem/homonuclear-diatomics.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16144d6e7ff5b05d805f37cef43c03ccbc1f27787242da1a12455f62708069a8
|
3 |
+
size 2132550
|
mlip_arena/tasks/diatomics/matgl/homonuclear-diatomics.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c257248d2f62bfd09eb334c49405287bf5bc3bf14cfc6ad0a5890425f559f91c
|
3 |
+
size 1854330
|
mlip_arena/tasks/diatomics/run.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
mlip_arena/tasks/eos/run.py
CHANGED
@@ -18,7 +18,7 @@ from prefect.futures import wait
|
|
18 |
from prefect.runtime import flow_run, task_run
|
19 |
from pymatgen.analysis.eos import BirchMurnaghan
|
20 |
|
21 |
-
from mlip_arena.models
|
22 |
from mlip_arena.tasks.optimize import run as OPT
|
23 |
|
24 |
if TYPE_CHECKING:
|
|
|
18 |
from prefect.runtime import flow_run, task_run
|
19 |
from pymatgen.analysis.eos import BirchMurnaghan
|
20 |
|
21 |
+
from mlip_arena.models import MLIPEnum
|
22 |
from mlip_arena.tasks.optimize import run as OPT
|
23 |
|
24 |
if TYPE_CHECKING:
|
mlip_arena/tasks/md.py
CHANGED
@@ -84,7 +84,8 @@ from scipy.linalg import schur
|
|
84 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
85 |
from tqdm.auto import tqdm
|
86 |
|
87 |
-
from mlip_arena.models
|
|
|
88 |
|
89 |
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
90 |
"nve": ("velocityverlet",),
|
|
|
84 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
85 |
from tqdm.auto import tqdm
|
86 |
|
87 |
+
from mlip_arena.models import MLIPEnum
|
88 |
+
from mlip_arena.models.utils import get_freer_device
|
89 |
|
90 |
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
91 |
"nve": ("velocityverlet",),
|
mlip_arena/tasks/optimize.py
CHANGED
@@ -17,7 +17,8 @@ from prefect import task
|
|
17 |
from prefect.tasks import task_input_hash
|
18 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
19 |
|
20 |
-
from mlip_arena.models
|
|
|
21 |
|
22 |
_valid_filters: dict[str, Filter] = {
|
23 |
"Filter": Filter,
|
|
|
17 |
from prefect.tasks import task_input_hash
|
18 |
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
19 |
|
20 |
+
from mlip_arena.models import MLIPEnum
|
21 |
+
from mlip_arena.models.utils import get_freer_device
|
22 |
|
23 |
_valid_filters: dict[str, Filter] = {
|
24 |
"Filter": Filter,
|
mlip_arena/tasks/run.py
DELETED
@@ -1,318 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from datetime import datetime, timedelta
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Literal, Sequence, Tuple
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
from ase import Atoms, units
|
9 |
-
from ase.calculators.calculator import Calculator
|
10 |
-
from ase.calculators.mixing import SumCalculator
|
11 |
-
from ase.io import read
|
12 |
-
from ase.io.trajectory import Trajectory
|
13 |
-
from ase.md.andersen import Andersen
|
14 |
-
from ase.md.langevin import Langevin
|
15 |
-
from ase.md.md import MolecularDynamics
|
16 |
-
from ase.md.npt import NPT
|
17 |
-
from ase.md.nptberendsen import NPTBerendsen
|
18 |
-
from ase.md.nvtberendsen import NVTBerendsen
|
19 |
-
from ase.md.velocitydistribution import (
|
20 |
-
MaxwellBoltzmannDistribution,
|
21 |
-
Stationary,
|
22 |
-
ZeroRotation,
|
23 |
-
)
|
24 |
-
from ase.md.verlet import VelocityVerlet
|
25 |
-
from prefect import task
|
26 |
-
from prefect.tasks import task_input_hash
|
27 |
-
from scipy.interpolate import interp1d
|
28 |
-
from scipy.linalg import schur
|
29 |
-
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
30 |
-
from tqdm.auto import tqdm
|
31 |
-
|
32 |
-
from mlip_arena.models.utils import MLIPEnum, get_freer_device
|
33 |
-
|
34 |
-
# from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
|
35 |
-
|
36 |
-
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
37 |
-
"nve": ("velocityverlet",),
|
38 |
-
"nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
|
39 |
-
"npt": ("nose-hoover", "berendsen"),
|
40 |
-
}
|
41 |
-
|
42 |
-
_preset_dynamics: dict = {
|
43 |
-
"nve_velocityverlet": VelocityVerlet,
|
44 |
-
"nvt_andersen": Andersen,
|
45 |
-
"nvt_berendsen": NVTBerendsen,
|
46 |
-
"nvt_langevin": Langevin,
|
47 |
-
"nvt_nose-hoover": NPT,
|
48 |
-
"npt_berendsen": NPTBerendsen,
|
49 |
-
"npt_nose-hoover": NPT,
|
50 |
-
}
|
51 |
-
|
52 |
-
|
53 |
-
def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
|
54 |
-
"""Interpolate temperature / pressure on a schedule."""
|
55 |
-
n_vals = len(values)
|
56 |
-
return np.interp(
|
57 |
-
np.linspace(0, n_vals - 1, n_pts + 1),
|
58 |
-
np.linspace(0, n_vals - 1, n_vals),
|
59 |
-
values,
|
60 |
-
)
|
61 |
-
|
62 |
-
|
63 |
-
def _get_ensemble_schedule(
|
64 |
-
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
65 |
-
n_steps: int = 1000,
|
66 |
-
temperature: float | Sequence | np.ndarray | None = 300.0,
|
67 |
-
pressure: float | Sequence | np.ndarray | None = None,
|
68 |
-
) -> Tuple[np.ndarray, np.ndarray]:
|
69 |
-
if ensemble == "nve":
|
70 |
-
# Disable thermostat and barostat
|
71 |
-
temperature = np.nan
|
72 |
-
pressure = np.nan
|
73 |
-
t_schedule = np.full(n_steps + 1, temperature)
|
74 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
75 |
-
return t_schedule, p_schedule
|
76 |
-
|
77 |
-
if isinstance(temperature, Sequence) or (
|
78 |
-
isinstance(temperature, np.ndarray) and temperature.ndim == 1
|
79 |
-
):
|
80 |
-
t_schedule = _interpolate_quantity(temperature, n_steps)
|
81 |
-
# NOTE: In ASE Langevin dynamics, the temperature are normally
|
82 |
-
# scalars, but in principle one quantity per atom could be specified by giving
|
83 |
-
# an array. This is not implemented yet here.
|
84 |
-
else:
|
85 |
-
t_schedule = np.full(n_steps + 1, temperature)
|
86 |
-
|
87 |
-
if ensemble == "nvt":
|
88 |
-
pressure = np.nan
|
89 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
90 |
-
return t_schedule, p_schedule
|
91 |
-
|
92 |
-
if isinstance(pressure, Sequence) or (
|
93 |
-
isinstance(pressure, np.ndarray) and pressure.ndim == 1
|
94 |
-
):
|
95 |
-
p_schedule = _interpolate_quantity(pressure, n_steps)
|
96 |
-
elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
|
97 |
-
p_schedule = interp1d(np.arange(n_steps + 1), pressure, kind="linear")
|
98 |
-
assert isinstance(p_schedule, np.ndarray)
|
99 |
-
else:
|
100 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
101 |
-
|
102 |
-
return t_schedule, p_schedule
|
103 |
-
|
104 |
-
|
105 |
-
def _get_ensemble_defaults(
|
106 |
-
ensemble: Literal["nve", "nvt", "npt"],
|
107 |
-
dynamics: str | MolecularDynamics,
|
108 |
-
t_schedule: np.ndarray,
|
109 |
-
p_schedule: np.ndarray,
|
110 |
-
ase_md_kwargs: dict | None = None,
|
111 |
-
) -> dict:
|
112 |
-
"""Update ASE MD kwargs"""
|
113 |
-
ase_md_kwargs = ase_md_kwargs or {}
|
114 |
-
|
115 |
-
if ensemble == "nve":
|
116 |
-
ase_md_kwargs.pop("temperature", None)
|
117 |
-
ase_md_kwargs.pop("temperature_K", None)
|
118 |
-
ase_md_kwargs.pop("externalstress", None)
|
119 |
-
elif ensemble == "nvt":
|
120 |
-
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
121 |
-
ase_md_kwargs.pop("externalstress", None)
|
122 |
-
elif ensemble == "npt":
|
123 |
-
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
124 |
-
ase_md_kwargs["externalstress"] = p_schedule[0] # * 1e3 * units.bar
|
125 |
-
|
126 |
-
if isinstance(dynamics, str) and dynamics.lower() == "langevin":
|
127 |
-
ase_md_kwargs["friction"] = ase_md_kwargs.get(
|
128 |
-
"friction",
|
129 |
-
10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
|
130 |
-
)
|
131 |
-
|
132 |
-
return ase_md_kwargs
|
133 |
-
|
134 |
-
|
135 |
-
@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
|
136 |
-
def md(
|
137 |
-
atoms: Atoms,
|
138 |
-
calculator_name: str | MLIPEnum,
|
139 |
-
calculator_kwargs: dict | None,
|
140 |
-
dispersion: str | None = None,
|
141 |
-
dispersion_kwargs: dict | None = None,
|
142 |
-
device: str | None = None,
|
143 |
-
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
144 |
-
dynamics: str | MolecularDynamics = "langevin",
|
145 |
-
time_step: float | None = None,
|
146 |
-
total_time: float = 1000,
|
147 |
-
temperature: float | Sequence | np.ndarray | None = 300.0,
|
148 |
-
pressure: float | Sequence | np.ndarray | None = None,
|
149 |
-
ase_md_kwargs: dict | None = None,
|
150 |
-
md_velocity_seed: int | None = None,
|
151 |
-
zero_linear_momentum: bool = True,
|
152 |
-
zero_angular_momentum: bool = True,
|
153 |
-
traj_file: str | Path | None = None,
|
154 |
-
traj_interval: int = 1,
|
155 |
-
restart: bool = True,
|
156 |
-
):
|
157 |
-
device = device or str(get_freer_device())
|
158 |
-
|
159 |
-
print(f"Using device: {device}")
|
160 |
-
|
161 |
-
calculator_kwargs = calculator_kwargs or {}
|
162 |
-
|
163 |
-
if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
|
164 |
-
assert issubclass(calculator_name.value, Calculator)
|
165 |
-
calc = calculator_name.value(**calculator_kwargs)
|
166 |
-
elif (
|
167 |
-
isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_
|
168 |
-
):
|
169 |
-
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
170 |
-
else:
|
171 |
-
raise ValueError(f"Invalid calculator: {calculator_name}")
|
172 |
-
|
173 |
-
print(f"Using calculator: {calc}")
|
174 |
-
|
175 |
-
dispersion_kwargs = dispersion_kwargs or {}
|
176 |
-
|
177 |
-
dispersion_kwargs.update({"device": device})
|
178 |
-
|
179 |
-
if dispersion is not None:
|
180 |
-
disp_calc = TorchDFTD3Calculator(
|
181 |
-
**dispersion_kwargs,
|
182 |
-
)
|
183 |
-
calc = SumCalculator([calc, disp_calc])
|
184 |
-
|
185 |
-
print(f"Using dispersion: {dispersion}")
|
186 |
-
|
187 |
-
atoms.calc = calc
|
188 |
-
|
189 |
-
if time_step is None:
|
190 |
-
# If a structure contains an isotope of hydrogen, set default `time_step`
|
191 |
-
# to 0.5 fs, and 2 fs otherwise.
|
192 |
-
has_h_isotope = "H" in atoms.get_chemical_symbols()
|
193 |
-
time_step = 0.5 if has_h_isotope else 2.0
|
194 |
-
|
195 |
-
n_steps = int(total_time / time_step)
|
196 |
-
target_steps = n_steps
|
197 |
-
|
198 |
-
t_schedule, p_schedule = _get_ensemble_schedule(
|
199 |
-
ensemble=ensemble,
|
200 |
-
n_steps=n_steps,
|
201 |
-
temperature=temperature,
|
202 |
-
pressure=pressure,
|
203 |
-
)
|
204 |
-
|
205 |
-
ase_md_kwargs = _get_ensemble_defaults(
|
206 |
-
ensemble=ensemble,
|
207 |
-
dynamics=dynamics,
|
208 |
-
t_schedule=t_schedule,
|
209 |
-
p_schedule=p_schedule,
|
210 |
-
ase_md_kwargs=ase_md_kwargs,
|
211 |
-
)
|
212 |
-
|
213 |
-
if isinstance(dynamics, str):
|
214 |
-
# Use known dynamics if `self.dynamics` is a str
|
215 |
-
dynamics = dynamics.lower()
|
216 |
-
if dynamics not in _valid_dynamics[ensemble]:
|
217 |
-
raise ValueError(
|
218 |
-
f"{dynamics} thermostat not available for {ensemble}."
|
219 |
-
f"Available {ensemble} thermostats are:"
|
220 |
-
" ".join(_valid_dynamics[ensemble])
|
221 |
-
)
|
222 |
-
if ensemble == "nve":
|
223 |
-
dynamics = "velocityverlet"
|
224 |
-
md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
|
225 |
-
elif dynamics is MolecularDynamics:
|
226 |
-
md_class = dynamics
|
227 |
-
else:
|
228 |
-
raise ValueError(f"Invalid dynamics: {dynamics}")
|
229 |
-
|
230 |
-
if md_class is NPT:
|
231 |
-
# Note that until md_func is instantiated, isinstance(md_func,NPT) is False
|
232 |
-
# ASE NPT implementation requires upper triangular cell
|
233 |
-
u, _ = schur(atoms.get_cell(complete=True), output="complex")
|
234 |
-
atoms.set_cell(u.real, scale_atoms=True)
|
235 |
-
|
236 |
-
last_step = 0
|
237 |
-
|
238 |
-
if traj_file is not None:
|
239 |
-
traj_file = Path(traj_file)
|
240 |
-
traj_file.parent.mkdir(parents=True, exist_ok=True)
|
241 |
-
|
242 |
-
if restart and traj_file.exists():
|
243 |
-
try:
|
244 |
-
traj = read(traj_file, index=":")
|
245 |
-
last_atoms = traj[-1]
|
246 |
-
assert isinstance(last_atoms, Atoms)
|
247 |
-
last_step = last_atoms.info.get("step", len(traj) * traj_interval)
|
248 |
-
n_steps -= last_step
|
249 |
-
traj = Trajectory(traj_file, "a", atoms)
|
250 |
-
atoms.set_positions(last_atoms.get_positions())
|
251 |
-
atoms.set_momenta(last_atoms.get_momenta())
|
252 |
-
except Exception:
|
253 |
-
traj = Trajectory(traj_file, "w", atoms)
|
254 |
-
|
255 |
-
if not np.isnan(t_schedule).any():
|
256 |
-
MaxwellBoltzmannDistribution(
|
257 |
-
atoms=atoms,
|
258 |
-
temperature_K=t_schedule[last_step],
|
259 |
-
rng=np.random.default_rng(seed=md_velocity_seed),
|
260 |
-
)
|
261 |
-
|
262 |
-
if zero_linear_momentum:
|
263 |
-
Stationary(atoms)
|
264 |
-
if zero_angular_momentum:
|
265 |
-
ZeroRotation(atoms)
|
266 |
-
else:
|
267 |
-
traj = Trajectory(traj_file, "w", atoms)
|
268 |
-
|
269 |
-
if not np.isnan(t_schedule).any():
|
270 |
-
MaxwellBoltzmannDistribution(
|
271 |
-
atoms=atoms,
|
272 |
-
temperature_K=t_schedule[last_step],
|
273 |
-
rng=np.random.default_rng(seed=md_velocity_seed),
|
274 |
-
)
|
275 |
-
|
276 |
-
if zero_linear_momentum:
|
277 |
-
Stationary(atoms)
|
278 |
-
if zero_angular_momentum:
|
279 |
-
ZeroRotation(atoms)
|
280 |
-
|
281 |
-
md_runner = md_class(
|
282 |
-
atoms=atoms,
|
283 |
-
timestep=time_step * units.fs,
|
284 |
-
**ase_md_kwargs,
|
285 |
-
)
|
286 |
-
|
287 |
-
if traj_file is not None:
|
288 |
-
md_runner.attach(traj.write, interval=traj_interval)
|
289 |
-
|
290 |
-
with tqdm(total=n_steps) as pbar:
|
291 |
-
|
292 |
-
def _callback(dyn: MolecularDynamics = md_runner) -> None:
|
293 |
-
step = last_step + dyn.nsteps
|
294 |
-
dyn.atoms.info["restart"] = last_step
|
295 |
-
dyn.atoms.info["datetime"] = datetime.now()
|
296 |
-
dyn.atoms.info["step"] = step
|
297 |
-
dyn.atoms.info["target_steps"] = target_steps
|
298 |
-
if ensemble == "nve":
|
299 |
-
return
|
300 |
-
dyn.set_temperature(temperature_K=t_schedule[step])
|
301 |
-
if ensemble == "nvt":
|
302 |
-
return
|
303 |
-
dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
|
304 |
-
pbar.update()
|
305 |
-
|
306 |
-
md_runner.attach(_callback, interval=1)
|
307 |
-
|
308 |
-
start_time = datetime.now()
|
309 |
-
md_runner.run(steps=n_steps)
|
310 |
-
end_time = datetime.now()
|
311 |
-
|
312 |
-
traj.close()
|
313 |
-
|
314 |
-
return {
|
315 |
-
"atoms": atoms,
|
316 |
-
"runtime": end_time - start_time,
|
317 |
-
"n_steps": n_steps,
|
318 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
@@ -44,11 +44,12 @@ run = [
|
|
44 |
"dgl==2.4.0",
|
45 |
"mace-torch==0.3.4",
|
46 |
"chgnet==0.3.8",
|
47 |
-
"fairchem-core==1.
|
48 |
"sevenn==0.9.3.post1",
|
49 |
"orb-models==0.3.1",
|
50 |
"alignn==2024.5.27",
|
51 |
-
"prefect>=3.0.4"
|
|
|
52 |
]
|
53 |
app = [
|
54 |
"streamlit==1.38.0",
|
@@ -62,7 +63,7 @@ test = [
|
|
62 |
"matgl==1.1.2",
|
63 |
"dgl==2.4.0",
|
64 |
"chgnet==0.3.8",
|
65 |
-
"fairchem-core==1.
|
66 |
"sevenn==0.9.3.post1",
|
67 |
"orb-models==0.3.1",
|
68 |
"alignn==2024.5.27",
|
|
|
44 |
"dgl==2.4.0",
|
45 |
"mace-torch==0.3.4",
|
46 |
"chgnet==0.3.8",
|
47 |
+
"fairchem-core==1.2.0",
|
48 |
"sevenn==0.9.3.post1",
|
49 |
"orb-models==0.3.1",
|
50 |
"alignn==2024.5.27",
|
51 |
+
"prefect>=3.0.4",
|
52 |
+
"prefect-dask"
|
53 |
]
|
54 |
app = [
|
55 |
"streamlit==1.38.0",
|
|
|
63 |
"matgl==1.1.2",
|
64 |
"dgl==2.4.0",
|
65 |
"chgnet==0.3.8",
|
66 |
+
"fairchem-core==1.2.0",
|
67 |
"sevenn==0.9.3.post1",
|
68 |
"orb-models==0.3.1",
|
69 |
"alignn==2024.5.27",
|
serve/leaderboard.py
CHANGED
@@ -25,6 +25,7 @@ table = pd.DataFrame(
|
|
25 |
"Prediction",
|
26 |
"NVT",
|
27 |
"NPT",
|
|
|
28 |
"Code",
|
29 |
"Paper",
|
30 |
"First Release",
|
@@ -42,6 +43,7 @@ for model in MODELS:
|
|
42 |
"Prediction": metadata.get("prediction", None),
|
43 |
"NVT": "✅" if metadata.get("nvt", False) else "❌",
|
44 |
"NPT": "✅" if metadata.get("npt", False) else "❌",
|
|
|
45 |
"Code": metadata.get("github", None) if metadata else None,
|
46 |
"Paper": metadata.get("doi", None) if metadata else None,
|
47 |
"First Release": metadata.get("date", None),
|
@@ -50,7 +52,6 @@ for model in MODELS:
|
|
50 |
|
51 |
table.set_index("Model", inplace=True)
|
52 |
|
53 |
-
|
54 |
s = table.style.background_gradient(
|
55 |
cmap="PuRd", subset=["Element Coverage"], vmin=0, vmax=120
|
56 |
)
|
|
|
25 |
"Prediction",
|
26 |
"NVT",
|
27 |
"NPT",
|
28 |
+
"Training Set",
|
29 |
"Code",
|
30 |
"Paper",
|
31 |
"First Release",
|
|
|
43 |
"Prediction": metadata.get("prediction", None),
|
44 |
"NVT": "✅" if metadata.get("nvt", False) else "❌",
|
45 |
"NPT": "✅" if metadata.get("npt", False) else "❌",
|
46 |
+
"Training Set": metadata.get("datasets", []),
|
47 |
"Code": metadata.get("github", None) if metadata else None,
|
48 |
"Paper": metadata.get("doi", None) if metadata else None,
|
49 |
"First Release": metadata.get("date", None),
|
|
|
52 |
|
53 |
table.set_index("Model", inplace=True)
|
54 |
|
|
|
55 |
s = table.style.background_gradient(
|
56 |
cmap="PuRd", subset=["Element Coverage"], vmin=0, vmax=120
|
57 |
)
|
serve/ranks/homonuclear-diatomics.py
CHANGED
@@ -72,6 +72,8 @@ table["Rank"] += np.argsort(table["Energy jump [eV]"].to_numpy())
|
|
72 |
table.sort_values("Force flips", ascending=True, inplace=True)
|
73 |
table["Rank"] += np.argsort(table["Force flips"].to_numpy())
|
74 |
|
|
|
|
|
75 |
table.sort_values("Rank", ascending=True, inplace=True)
|
76 |
|
77 |
table["Rank aggr."] = table["Rank"]
|
|
|
72 |
table.sort_values("Force flips", ascending=True, inplace=True)
|
73 |
table["Rank"] += np.argsort(table["Force flips"].to_numpy())
|
74 |
|
75 |
+
table["Rank"] += 1
|
76 |
+
|
77 |
table.sort_values("Rank", ascending=True, inplace=True)
|
78 |
|
79 |
table["Rank aggr."] = table["Rank"]
|
serve/tasks/homonuclear-diatomics.py
CHANGED
@@ -30,7 +30,7 @@ valid_models = [
|
|
30 |
mlip_methods = container.multiselect(
|
31 |
"MLIPs",
|
32 |
valid_models,
|
33 |
-
["EquiformerV2(OC22)", "CHGNet", "M3GNet", "SevenNet", "MACE-MP(M)", "ORB"],
|
34 |
)
|
35 |
dft_methods = container.multiselect("DFT Methods", ["GPAW"], [])
|
36 |
|
@@ -139,6 +139,7 @@ def get_plots(df, energy_plot: bool, force_plot: bool, method_color_mapping: dic
|
|
139 |
ys = es
|
140 |
|
141 |
elo = min(elo, max(ys.min() * 1.2, -15), -1)
|
|
|
142 |
|
143 |
fig.add_trace(
|
144 |
go.Scatter(
|
@@ -202,7 +203,7 @@ def get_plots(df, energy_plot: bool, force_plot: bool, method_color_mapping: dic
|
|
202 |
yaxis=dict(
|
203 |
title=dict(text="Energy [eV]"),
|
204 |
side="left",
|
205 |
-
range=[elo,
|
206 |
)
|
207 |
)
|
208 |
|
|
|
30 |
mlip_methods = container.multiselect(
|
31 |
"MLIPs",
|
32 |
valid_models,
|
33 |
+
["EquiformerV2(OC22)", "CHGNet", "M3GNet", "SevenNet", "MACE-MP(M)", "ORB", "eqV2(OMat)"],
|
34 |
)
|
35 |
dft_methods = container.multiselect("DFT Methods", ["GPAW"], [])
|
36 |
|
|
|
139 |
ys = es
|
140 |
|
141 |
elo = min(elo, max(ys.min() * 1.2, -15), -1)
|
142 |
+
# elo = min(elo, ys.min()*1.2, -1)
|
143 |
|
144 |
fig.add_trace(
|
145 |
go.Scatter(
|
|
|
203 |
yaxis=dict(
|
204 |
title=dict(text="Energy [eV]"),
|
205 |
side="left",
|
206 |
+
range=[elo, 2.0 * (abs(elo))],
|
207 |
)
|
208 |
)
|
209 |
|
tests/test_eos.py
CHANGED
@@ -1,34 +1,41 @@
|
|
1 |
import pytest
|
2 |
from ase.build import bulk
|
|
|
3 |
|
4 |
-
from mlip_arena.models
|
5 |
from mlip_arena.tasks.eos.run import fit as EOS
|
|
|
6 |
|
7 |
atoms = bulk("Cu", "fcc", a=3.6)
|
8 |
|
|
|
|
|
|
|
|
|
9 |
|
|
|
10 |
@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
|
11 |
def test_eos(model: MLIPEnum):
|
12 |
"""
|
13 |
Test EOS prefect workflow with a simple cubic lattice.
|
14 |
"""
|
15 |
|
16 |
-
|
17 |
-
atoms=atoms,
|
18 |
-
calculator_name=model.name,
|
19 |
-
calculator_kwargs={},
|
20 |
-
device=None,
|
21 |
-
optimizer="BFGSLineSearch",
|
22 |
-
optimizer_kwargs=None,
|
23 |
-
filter="FrechetCell",
|
24 |
-
filter_kwargs=None,
|
25 |
-
criterion=dict(
|
26 |
-
fmax=0.1,
|
27 |
-
),
|
28 |
-
max_abs_strain=0.1,
|
29 |
-
npoints=6,
|
30 |
-
)
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
|
|
1 |
import pytest
|
2 |
from ase.build import bulk
|
3 |
+
from prefect.testing.utilities import prefect_test_harness
|
4 |
|
5 |
+
from mlip_arena.models import MLIPEnum
|
6 |
from mlip_arena.tasks.eos.run import fit as EOS
|
7 |
+
import sys
|
8 |
|
9 |
atoms = bulk("Cu", "fcc", a=3.6)
|
10 |
|
11 |
+
# @pytest.fixture(autouse=True, scope="session")
|
12 |
+
# def prefect_test_fixture():
|
13 |
+
# with prefect_test_harness():
|
14 |
+
# yield
|
15 |
|
16 |
+
@pytest.mark.skipif(sys.version_info[:2] != (3,11), reason="avoid prefect race condition on concurrent tasks")
|
17 |
@pytest.mark.parametrize("model", [MLIPEnum["MACE-MP(M)"]])
|
18 |
def test_eos(model: MLIPEnum):
|
19 |
"""
|
20 |
Test EOS prefect workflow with a simple cubic lattice.
|
21 |
"""
|
22 |
|
23 |
+
with prefect_test_harness():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
result = EOS(
|
26 |
+
atoms=atoms,
|
27 |
+
calculator_name=model.name,
|
28 |
+
calculator_kwargs={},
|
29 |
+
device=None,
|
30 |
+
optimizer="BFGSLineSearch",
|
31 |
+
optimizer_kwargs=None,
|
32 |
+
filter="FrechetCell",
|
33 |
+
filter_kwargs=None,
|
34 |
+
criterion=dict(
|
35 |
+
fmax=0.1,
|
36 |
+
),
|
37 |
+
max_abs_strain=0.1,
|
38 |
+
npoints=6,
|
39 |
+
)
|
40 |
|
41 |
+
assert isinstance(result["K"], float)
|
tests/test_external_calculators.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import pytest
|
2 |
from ase import Atoms
|
3 |
|
4 |
-
from mlip_arena.models
|
5 |
|
6 |
|
7 |
@pytest.mark.parametrize("model", MLIPEnum)
|
|
|
1 |
import pytest
|
2 |
from ase import Atoms
|
3 |
|
4 |
+
from mlip_arena.models import MLIPEnum
|
5 |
|
6 |
|
7 |
@pytest.mark.parametrize("model", MLIPEnum)
|