cyrusyc commited on
Commit
7cbf186
1 Parent(s): 0ffedd3

refactor external calculators; better handle devices

Browse files
mlip_arena/models/__init__.py CHANGED
@@ -18,37 +18,29 @@ class MLIP(
18
  PyTorchModelHubMixin,
19
  tags=["atomistic-simulation", "MLIP"],
20
  ):
21
- def __init__(self, *args, **kwargs) -> None:
22
- super().__init__(*args, **kwargs)
23
-
24
-
25
- class ModuleMLIP(MLIP):
26
- def __init__(self, model: nn.Module, *args, **kwargs) -> None:
27
- super().__init__(*args, **kwargs)
28
- self.add_module("model", model)
29
 
30
  def forward(self, x):
31
- print("Forwarding...")
32
- out = self.model(x)
33
- print("Forwarded!")
34
- return out
35
-
36
 
37
- class MLIPCalculator(Calculator):
38
  name: str
39
- # device: torch.device
40
- # model: MLIP
41
  implemented_properties: list[str] = ["energy", "forces", "stress"]
42
 
43
  def __init__(
44
  self,
 
45
  # ASE Calculator
46
  restart=None,
47
  atoms=None,
48
  directory=".",
49
- **kwargs,
50
  ):
51
- super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
 
 
52
  # self.name: str = self.__class__.__name__
53
  # self.device = device or torch.device(
54
  # "cuda" if torch.cuda.is_available() else "cpu"
 
18
  PyTorchModelHubMixin,
19
  tags=["atomistic-simulation", "MLIP"],
20
  ):
21
+ def __init__(self, model: nn.Module) -> None:
22
+ super().__init__()
23
+ self.model = model
 
 
 
 
 
24
 
25
  def forward(self, x):
26
+ return self.model(x)
 
 
 
 
27
 
28
+ class MLIPCalculator(MLIP, Calculator):
29
  name: str
 
 
30
  implemented_properties: list[str] = ["energy", "forces", "stress"]
31
 
32
  def __init__(
33
  self,
34
+ model,
35
  # ASE Calculator
36
  restart=None,
37
  atoms=None,
38
  directory=".",
39
+ calculator_kwargs: dict = {},
40
  ):
41
+ MLIP.__init__(self, model=model) # Initialize MLIP part
42
+ Calculator.__init__(self, restart=restart, atoms=atoms, directory=directory, **calculator_kwargs) # Initialize ASE Calculator part
43
+ # Additional initialization if needed
44
  # self.name: str = self.__class__.__name__
45
  # self.device = device or torch.device(
46
  # "cuda" if torch.cuda.is_available() else "cpu"
mlip_arena/models/chgnet.py CHANGED
@@ -7,10 +7,12 @@ from ase.calculators.calculator import all_changes
7
  from huggingface_hub import hf_hub_download
8
  from torch_geometric.data import Data
9
 
10
- from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP
11
 
 
12
 
13
- class CHGNetCalculator(MLIPCalculator):
 
14
  def __init__(
15
  self,
16
  device: torch.device | None = None,
@@ -19,23 +21,15 @@ class CHGNetCalculator(MLIPCalculator):
19
  directory=".",
20
  **kwargs,
21
  ):
22
- super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
23
-
24
- self.name: str = self.__class__.__name__
25
-
26
- fpath = hf_hub_download(
27
- repo_id="cyrusyc/mace-universal",
28
- subfolder="pretrained",
29
- filename="2023-12-12-mace-128-L1_epoch-199.model",
30
- revision="main",
31
- )
32
-
33
  self.device = device or torch.device(
34
  "cuda" if torch.cuda.is_available() else "cpu"
35
  )
36
 
37
- self.model = torch.load(fpath, map_location=self.device)
 
 
38
 
 
39
  self.implemented_properties = ["energy", "forces", "stress"]
40
 
41
  def calculate(
 
7
  from huggingface_hub import hf_hub_download
8
  from torch_geometric.data import Data
9
 
10
+ from mlip_arena.models import MLIP, MLIPCalculator
11
 
12
+ # TODO: WIP
13
 
14
+
15
+ class CHGNet(MLIPCalculator):
16
  def __init__(
17
  self,
18
  device: torch.device | None = None,
 
21
  directory=".",
22
  **kwargs,
23
  ):
 
 
 
 
 
 
 
 
 
 
 
24
  self.device = device or torch.device(
25
  "cuda" if torch.cuda.is_available() else "cpu"
26
  )
27
 
28
+ super().__init__(
29
+ model=model, restart=restart, atoms=atoms, directory=directory, **kwargs
30
+ )
31
 
32
+ self.name: str = self.__class__.__name__
33
  self.implemented_properties = ["energy", "forces", "stress"]
34
 
35
  def calculate(
mlip_arena/models/externals.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib
3
+ from typing import Literal
4
+
5
+ import torch
6
+ from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff
7
+ from ase import Atoms
8
+ from chgnet.model.dynamics import CHGNetCalculator
9
+ from chgnet.model.model import CHGNet
10
+ from fairchem.core import OCPCalculator
11
+ from mace.calculators import MACECalculator
12
+
13
+
14
+ # Avoid circular import
15
+ def get_freer_device() -> torch.device:
16
+ """Get the GPU with the most free memory, or use MPS if available.
17
+ s
18
+ Returns:
19
+ torch.device: The selected GPU device or MPS.
20
+
21
+ Raises:
22
+ ValueError: If no GPU or MPS is available.
23
+ """
24
+ device_count = torch.cuda.device_count()
25
+ if device_count > 0:
26
+ # If CUDA GPUs are available, select the one with the most free memory
27
+ mem_free = [
28
+ torch.cuda.get_device_properties(i).total_memory
29
+ - torch.cuda.memory_allocated(i)
30
+ for i in range(device_count)
31
+ ]
32
+ free_gpu_index = mem_free.index(max(mem_free))
33
+ device = torch.device(f"cuda:{free_gpu_index}")
34
+ print(
35
+ f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
36
+ )
37
+ elif torch.backends.mps.is_available():
38
+ # If no CUDA GPUs are available but MPS is, use MPS
39
+ print("No GPU available. Using MPS.")
40
+ device = torch.device("mps")
41
+ else:
42
+ # Fallback to CPU if neither CUDA GPUs nor MPS are available
43
+ print("No GPU or MPS available. Using CPU.")
44
+ device = torch.device("cpu")
45
+
46
+ return device
47
+
48
+
49
+ class MACE_MP_Medium(MACECalculator):
50
+ def __init__(self, device=None, default_dtype="float32", **kwargs):
51
+ checkpoint_url = "http://tinyurl.com/5yyxdm76"
52
+ cache_dir = os.path.expanduser("~/.cache/mace")
53
+ checkpoint_url_name = "".join(
54
+ c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
55
+ )
56
+ cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
57
+ if not os.path.isfile(cached_model_path):
58
+ os.makedirs(cache_dir, exist_ok=True)
59
+ # download and save to disk
60
+ print(f"Downloading MACE model from {checkpoint_url!r}")
61
+ _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
62
+ if "Content-Type: text/html" in http_msg:
63
+ raise RuntimeError(
64
+ f"Model download failed, please check the URL {checkpoint_url}"
65
+ )
66
+ print(f"Cached MACE model to {cached_model_path}")
67
+ model = cached_model_path
68
+ msg = f"Using Materials Project MACE for MACECalculator with {model}"
69
+ print(msg)
70
+
71
+ device = device or str(get_freer_device())
72
+
73
+ super().__init__(
74
+ model_paths=model, device=device, default_dtype=default_dtype, **kwargs
75
+ )
76
+
77
+
78
+ class CHGNet(CHGNetCalculator):
79
+ def __init__(
80
+ self,
81
+ model: CHGNet | None = None,
82
+ use_device: str | None = None,
83
+ stress_weight: float | None = 1 / 160.21766208,
84
+ on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
85
+ **kwargs,
86
+ ) -> None:
87
+ use_device = use_device or str(get_freer_device())
88
+ super().__init__(
89
+ model=model,
90
+ use_device=use_device,
91
+ stress_weight=stress_weight,
92
+ on_isolated_atoms=on_isolated_atoms,
93
+ **kwargs,
94
+ )
95
+
96
+ def calculate(
97
+ self,
98
+ atoms: Atoms | None = None,
99
+ properties: list | None = None,
100
+ system_changes: list | None = None,
101
+ ) -> None:
102
+ super().calculate(atoms, properties, system_changes)
103
+
104
+ # for ase.io.write compatibility
105
+ self.results.pop("crystal_fea", None)
106
+
107
+
108
+ class EquiformerV2(OCPCalculator):
109
+ def __init__(
110
+ self,
111
+ model_name="EquiformerV2-lE4-lF100-S2EFS-OC22",
112
+ local_cache="/tmp/ocp/",
113
+ cpu=False,
114
+ seed=0,
115
+ **kwargs,
116
+ ) -> None:
117
+ super().__init__(
118
+ model_name=model_name,
119
+ local_cache=local_cache,
120
+ cpu=cpu,
121
+ seed=0,
122
+ **kwargs,
123
+ )
124
+
125
+ def calculate(self, atoms: Atoms, properties, system_changes) -> None:
126
+ super().calculate(atoms, properties, system_changes)
127
+
128
+ self.results.update(
129
+ force=atoms.get_forces(),
130
+ )
131
+
132
+
133
+ class eSCN(OCPCalculator):
134
+ def __init__(
135
+ self,
136
+ model_name="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD",
137
+ local_cache="/tmp/ocp/",
138
+ cpu=False,
139
+ seed=0,
140
+ **kwargs,
141
+ ) -> None:
142
+ super().__init__(
143
+ model_name=model_name,
144
+ local_cache=local_cache,
145
+ cpu=cpu,
146
+ seed=0,
147
+ **kwargs,
148
+ )
149
+
150
+ def calculate(self, atoms: Atoms, properties, system_changes) -> None:
151
+ super().calculate(atoms, properties, system_changes)
152
+
153
+ self.results.update(
154
+ force=atoms.get_forces(),
155
+ )
156
+
157
+
158
+ class ALIGNN(AlignnAtomwiseCalculator):
159
+ def __init__(self, dir_path: str = "/tmp/alignn/", device=None, **kwargs) -> None:
160
+ model_path = get_figshare_model_ff(dir_path=dir_path)
161
+ device = device or get_freer_device()
162
+ super().__init__(model_path=model_path, device=device, **kwargs)
163
+
164
+ def calculate(self, atoms, properties=None, system_changes=None):
165
+ super().calculate(atoms, properties, system_changes)
mlip_arena/models/mace.py CHANGED
@@ -1,13 +1,10 @@
1
- from typing import Optional, Tuple
2
-
3
- import numpy as np
4
  import torch
5
  from ase import Atoms
6
  from ase.calculators.calculator import all_changes
7
  from huggingface_hub import hf_hub_download
8
  from torch_geometric.data import Data
9
 
10
- from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP
11
 
12
 
13
  class MACE_MP_Medium(MLIPCalculator):
@@ -19,9 +16,9 @@ class MACE_MP_Medium(MLIPCalculator):
19
  directory=".",
20
  **kwargs,
21
  ):
22
- super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
23
-
24
- self.name: str = self.__class__.__name__
25
 
26
  fpath = hf_hub_download(
27
  repo_id="cyrusyc/mace-universal",
@@ -30,23 +27,15 @@ class MACE_MP_Medium(MLIPCalculator):
30
  revision="main",
31
  )
32
 
33
- self.device = device or torch.device(
34
- "cuda" if torch.cuda.is_available() else "cpu"
35
- )
36
 
37
- self.model = torch.load(fpath, map_location=self.device)
 
 
38
 
 
39
  self.implemented_properties = ["energy", "forces", "stress"]
40
 
41
- # repo_id = f"atomind/{self.__class__.__name__}".lower().replace("_", "-")
42
-
43
- # model = ModuleMLIP(model=model)
44
- # model.save_pretrained(
45
- # self.__class__.__name__.lower().replace("_", "-"),
46
- # repo_id=repo_id,
47
- # push_to_hub=True,
48
- # )
49
-
50
  def calculate(
51
  self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
52
  ):
 
 
 
 
1
  import torch
2
  from ase import Atoms
3
  from ase.calculators.calculator import all_changes
4
  from huggingface_hub import hf_hub_download
5
  from torch_geometric.data import Data
6
 
7
+ from mlip_arena.models import MLIPCalculator
8
 
9
 
10
  class MACE_MP_Medium(MLIPCalculator):
 
16
  directory=".",
17
  **kwargs,
18
  ):
19
+ self.device = device or torch.device(
20
+ "cuda" if torch.cuda.is_available() else "cpu"
21
+ )
22
 
23
  fpath = hf_hub_download(
24
  repo_id="cyrusyc/mace-universal",
 
27
  revision="main",
28
  )
29
 
30
+ model = torch.load(fpath, map_location=self.device)
 
 
31
 
32
+ super().__init__(
33
+ model=model, restart=restart, atoms=atoms, directory=directory, **kwargs
34
+ )
35
 
36
+ self.name: str = self.__class__.__name__
37
  self.implemented_properties = ["energy", "forces", "stress"]
38
 
 
 
 
 
 
 
 
 
 
39
  def calculate(
40
  self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
41
  ):
mlip_arena/models/registry.yaml CHANGED
@@ -1,7 +1,7 @@
1
 
2
 
3
  MACE-MP(M):
4
- module: mace
5
  class: MACE_MP_Medium
6
  username: cyrusyc # HF username
7
  last-update: 2024-03-25T14:30:00
@@ -17,6 +17,40 @@ MACE-MP(M):
17
  doi: https://arxiv.org/abs/2401.00096
18
  date: 2023-12-29
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # CHGNet:
21
  # module: chgnet
22
  # username: cyrusyc
 
1
 
2
 
3
  MACE-MP(M):
4
+ module: externals
5
  class: MACE_MP_Medium
6
  username: cyrusyc # HF username
7
  last-update: 2024-03-25T14:30:00
 
17
  doi: https://arxiv.org/abs/2401.00096
18
  date: 2023-12-29
19
 
20
+ CHGNet:
21
+ module: externals
22
+ class: CHGNet
23
+ username: cyrusyc
24
+ last-update: 2024-07-08T00:00:00
25
+ datetime: 2024-07-08T00:00:00
26
+ datasets:
27
+ - atomind/mptrj
28
+ gpu-tasks:
29
+ - diatomics
30
+
31
+ EquiformerV2(OC22):
32
+ module: externals
33
+ class: EquiformerV2
34
+ username: cyrusyc
35
+ last-update: 2024-07-08T00:00:00
36
+ datetime: 2024-07-08T00:00:00
37
+ datasets:
38
+ - ocp
39
+ gpu-tasks:
40
+ - diatomics
41
+
42
+ eSCN(OC20):
43
+ module: externals
44
+ class: eSCN
45
+ username: cyrusyc
46
+ last-update: 2024-07-08T00:00:00
47
+ datetime: 2024-07-08T00:00:00
48
+ datasets:
49
+ - ocp
50
+ gpu-tasks:
51
+ - diatomics
52
+
53
+
54
  # CHGNet:
55
  # module: chgnet
56
  # username: cyrusyc
mlip_arena/models/utils.py CHANGED
@@ -2,10 +2,8 @@
2
 
3
  import importlib
4
  from enum import Enum
5
- from typing import Any
6
 
7
  import torch
8
- from ase.calculators.calculator import Calculator
9
 
10
  from mlip_arena.models import REGISTRY
11
 
@@ -15,73 +13,112 @@ MLIPMap = {
15
  )
16
  for model, metadata in REGISTRY.items()
17
  }
 
18
 
19
 
20
- class EXTMLIPEnum(Enum):
21
- """Enumeration class for EXTMLIP models.
 
 
 
22
 
23
- Attributes:
24
- M3GNet (str): M3GNet model.
25
- CHGNet (str): CHGNet model.
26
- MACE (str): MACE model.
27
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- M3GNet = "M3GNet"
30
- CHGNet = "CHGNet"
31
- MACE = "MACE"
32
 
33
 
34
- def get_freer_device() -> torch.device:
35
- """Get the GPU with the most free memory.
36
 
37
- Returns:
38
- torch.device: The selected GPU device.
 
 
 
39
 
40
- Raises:
41
- ValueError: If no GPU is available.
42
- """
43
- device_count = torch.cuda.device_count()
44
- if device_count == 0:
45
- print("No GPU available. Using CPU.")
46
- return torch.device("cpu")
47
 
48
- mem_free = [
49
- torch.cuda.get_device_properties(i).total_memory
50
- - torch.cuda.memory_allocated(i)
51
- for i in range(device_count)
52
- ]
53
 
54
- free_gpu_index = mem_free.index(max(mem_free))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- print(
57
- f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
58
- )
59
 
60
- return torch.device(f"cuda:{free_gpu_index}")
 
 
 
61
 
 
 
62
 
63
- def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
64
- """Construct an ASE calculator from an external third-party MLIP packages"""
65
- calculator = None
66
- device = get_freer_device()
67
 
68
- if name == EXTMLIPEnum.MACE:
69
- from mace.calculators import mace_mp
70
 
71
- calculator = mace_mp(device=str(device), **kwargs)
72
 
73
- elif name == EXTMLIPEnum.CHGNet:
74
- from chgnet.model.dynamics import CHGNetCalculator
 
75
 
76
- calculator = CHGNetCalculator(use_device=str(device), **kwargs)
 
77
 
78
- elif name == EXTMLIPEnum.M3GNet:
79
- import matgl
80
- from matgl.ext.ase import PESCalculator
81
 
82
- potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
83
- calculator = PESCalculator(potential, **kwargs)
84
 
85
- calculator.__setattr__("name", name.value)
86
 
87
- return calculator
 
2
 
3
  import importlib
4
  from enum import Enum
 
5
 
6
  import torch
 
7
 
8
  from mlip_arena.models import REGISTRY
9
 
 
13
  )
14
  for model, metadata in REGISTRY.items()
15
  }
16
+ MLIPEnum = Enum("MLIPEnum", MLIPMap)
17
 
18
 
19
+ def get_freer_device() -> torch.device:
20
+ """Get the GPU with the most free memory, or use MPS if available.
21
+ s
22
+ Returns:
23
+ torch.device: The selected GPU device or MPS.
24
 
25
+ Raises:
26
+ ValueError: If no GPU or MPS is available.
 
 
27
  """
28
+ device_count = torch.cuda.device_count()
29
+ if device_count > 0:
30
+ # If CUDA GPUs are available, select the one with the most free memory
31
+ mem_free = [
32
+ torch.cuda.get_device_properties(i).total_memory
33
+ - torch.cuda.memory_allocated(i)
34
+ for i in range(device_count)
35
+ ]
36
+ free_gpu_index = mem_free.index(max(mem_free))
37
+ device = torch.device(f"cuda:{free_gpu_index}")
38
+ print(
39
+ f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
40
+ )
41
+ elif torch.backends.mps.is_available():
42
+ # If no CUDA GPUs are available but MPS is, use MPS
43
+ print("No GPU available. Using MPS.")
44
+ device = torch.device("mps")
45
+ else:
46
+ # Fallback to CPU if neither CUDA GPUs nor MPS are available
47
+ print("No GPU or MPS available. Using CPU.")
48
+ device = torch.device("cpu")
49
 
50
+ return device
 
 
51
 
52
 
53
+ # class EXTMLIPEnum(Enum):
54
+ # """Enumeration class for EXTMLIP models.
55
 
56
+ # Attributes:
57
+ # M3GNet (str): M3GNet model.
58
+ # CHGNet (str): CHGNet model.
59
+ # MACE (str): MACE model.
60
+ # """
61
 
62
+ # M3GNet = "M3GNet"
63
+ # CHGNet = "CHGNet"
64
+ # MACE = "MACE"
65
+ # Equiformer = "Equiformer"
 
 
 
66
 
 
 
 
 
 
67
 
68
+ # def get_freer_device() -> torch.device:
69
+ # """Get the GPU with the most free memory.
70
+
71
+ # Returns:
72
+ # torch.device: The selected GPU device.
73
+
74
+ # Raises:
75
+ # ValueError: If no GPU is available.
76
+ # """
77
+ # device_count = torch.cuda.device_count()
78
+ # if device_count == 0:
79
+ # print("No GPU available. Using CPU.")
80
+ # return torch.device("cpu")
81
+
82
+ # mem_free = [
83
+ # torch.cuda.get_device_properties(i).total_memory
84
+ # - torch.cuda.memory_allocated(i)
85
+ # for i in range(device_count)
86
+ # ]
87
+
88
+ # free_gpu_index = mem_free.index(max(mem_free))
89
+
90
+ # print(
91
+ # f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
92
+ # )
93
+
94
+ # return torch.device(f"cuda:{free_gpu_index}")
95
+
96
 
 
 
 
97
 
98
+ # def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
99
+ # """Construct an ASE calculator from an external third-party MLIP packages"""
100
+ # calculator = None
101
+ # device = get_freer_device()
102
 
103
+ # if name == EXTMLIPEnum.MACE:
104
+ # from mace.calculators import mace_mp
105
 
106
+ # calculator = mace_mp(device=str(device), **kwargs)
 
 
 
107
 
108
+ # elif name == EXTMLIPEnum.CHGNet:
109
+ # from chgnet.model.dynamics import CHGNetCalculator
110
 
111
+ # calculator = CHGNetCalculator(use_device=str(device), **kwargs)
112
 
113
+ # elif name == EXTMLIPEnum.M3GNet:
114
+ # import matgl
115
+ # from matgl.ext.ase import PESCalculator
116
 
117
+ # potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
118
+ # calculator = PESCalculator(potential, **kwargs)
119
 
 
 
 
120
 
 
 
121
 
122
+ # calculator.__setattr__("name", name.value)
123
 
124
+ # return calculator
mlip_arena/tasks/stability/__init__.py CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+ from .run import md as MD
3
+
mlip_arena/tasks/stability/run.py CHANGED
@@ -1,39 +1,141 @@
1
  from __future__ import annotations
2
 
3
- import datetime
4
- from datetime import datetime
5
  from pathlib import Path
6
- from typing import Literal, Sequence
7
 
8
  import numpy as np
9
- import torch
10
  from ase import Atoms, units
 
11
  from ase.calculators.mixing import SumCalculator
12
  from ase.io import read
13
  from ase.io.trajectory import Trajectory
 
 
14
  from ase.md.md import MolecularDynamics
15
  from ase.md.npt import NPT
 
 
16
  from ase.md.velocitydistribution import (
17
  MaxwellBoltzmannDistribution,
18
  Stationary,
19
  ZeroRotation,
20
  )
 
 
 
 
21
  from scipy.linalg import schur
22
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
23
  from tqdm.auto import tqdm
24
 
25
- from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
26
- from mlip_arena.tasks.utils import (
27
- _get_ensemble_defaults,
28
- _get_ensemble_schedule,
29
- _preset_dynamics,
30
- _valid_dynamics,
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
 
34
  def md(
35
  atoms: Atoms,
36
- calculator_name: str | EXTMLIPEnum,
37
  calculator_kwargs: dict | None,
38
  dispersion: str | None = None,
39
  dispersion_kwargs: dict | None = None,
@@ -59,16 +161,19 @@ def md(
59
  # device: str | None = None,
60
  # dtype: str = "float64",
61
  ):
62
- device = device or ("cuda" if torch.cuda.is_available() else "cpu")
63
 
64
  print(f"Using device: {device}")
65
 
66
  calculator_kwargs = calculator_kwargs or {}
67
 
68
- if isinstance(calculator_name, EXTMLIPEnum) and calculator_name in EXTMLIPEnum:
69
- calc = external_ase_calculator(calculator_name, **calculator_kwargs)
70
- elif calculator_name in MLIPMap:
71
- calc = MLIPMap[calculator_name](**calculator_kwargs)
 
 
 
72
 
73
  print(f"Using calculator: {calc}")
74
 
@@ -171,12 +276,15 @@ def md(
171
  with tqdm(total=n_steps) as pbar:
172
 
173
  def _callback(dyn: MolecularDynamics = md_runner) -> None:
 
 
 
174
  if ensemble == "nve":
175
  return
176
- dyn.set_temperature(temperature_K=t_schedule[last_step + dyn.nsteps])
177
  if ensemble == "nvt":
178
  return
179
- dyn.set_stress(p_schedule[last_step + dyn.nsteps] * 1e3 * units.bar)
180
  pbar.update()
181
 
182
  md_runner.attach(_callback, interval=1)
@@ -187,4 +295,4 @@ def md(
187
 
188
  traj.close()
189
 
190
- return {"md_runtime": end_time - start_time}
 
1
  from __future__ import annotations
2
 
3
+ from datetime import datetime, timedelta
 
4
  from pathlib import Path
5
+ from typing import Literal, Sequence, Tuple
6
 
7
  import numpy as np
 
8
  from ase import Atoms, units
9
+ from ase.calculators.calculator import Calculator
10
  from ase.calculators.mixing import SumCalculator
11
  from ase.io import read
12
  from ase.io.trajectory import Trajectory
13
+ from ase.md.andersen import Andersen
14
+ from ase.md.langevin import Langevin
15
  from ase.md.md import MolecularDynamics
16
  from ase.md.npt import NPT
17
+ from ase.md.nptberendsen import NPTBerendsen
18
+ from ase.md.nvtberendsen import NVTBerendsen
19
  from ase.md.velocitydistribution import (
20
  MaxwellBoltzmannDistribution,
21
  Stationary,
22
  ZeroRotation,
23
  )
24
+ from ase.md.verlet import VelocityVerlet
25
+ from prefect import task
26
+ from prefect.tasks import task_input_hash
27
+ from scipy.interpolate import interp1d
28
  from scipy.linalg import schur
29
  from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
30
  from tqdm.auto import tqdm
31
 
32
+ from mlip_arena.models.utils import MLIPEnum, get_freer_device
33
+
34
+ # from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
35
+
36
+ _valid_dynamics: dict[str, tuple[str, ...]] = {
37
+ "nve": ("velocityverlet",),
38
+ "nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
39
+ "npt": ("nose-hoover", "berendsen"),
40
+ }
41
+
42
+ _preset_dynamics: dict = {
43
+ "nve_velocityverlet": VelocityVerlet,
44
+ "nvt_andersen": Andersen,
45
+ "nvt_berendsen": NVTBerendsen,
46
+ "nvt_langevin": Langevin,
47
+ "nvt_nose-hoover": NPT,
48
+ "npt_berendsen": NPTBerendsen,
49
+ "npt_nose-hoover": NPT,
50
+ }
51
+
52
+ def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
53
+ """Interpolate temperature / pressure on a schedule."""
54
+ n_vals = len(values)
55
+ return np.interp(
56
+ np.linspace(0, n_vals - 1, n_pts + 1),
57
+ np.linspace(0, n_vals - 1, n_vals),
58
+ values,
59
+ )
60
+
61
+ def _get_ensemble_schedule(
62
+ ensemble: Literal["nve", "nvt", "npt"] = "nvt",
63
+ n_steps: int = 1000,
64
+ temperature: float | Sequence | np.ndarray | None = 300.0,
65
+ pressure: float | Sequence | np.ndarray | None = None
66
+ ) -> Tuple[np.ndarray, np.ndarray]:
67
+ if ensemble == "nve":
68
+ # Disable thermostat and barostat
69
+ temperature = np.nan
70
+ pressure = np.nan
71
+ t_schedule = np.full(n_steps + 1, temperature)
72
+ p_schedule = np.full(n_steps + 1, pressure)
73
+ return t_schedule, p_schedule
74
+
75
+ if isinstance(temperature, Sequence) or (
76
+ isinstance(temperature, np.ndarray) and temperature.ndim == 1
77
+ ):
78
+ t_schedule = _interpolate_quantity(temperature, n_steps)
79
+ # NOTE: In ASE Langevin dynamics, the temperature are normally
80
+ # scalars, but in principle one quantity per atom could be specified by giving
81
+ # an array. This is not implemented yet here.
82
+ else:
83
+ t_schedule = np.full(n_steps + 1, temperature)
84
+
85
+ if ensemble == "nvt":
86
+ pressure = np.nan
87
+ p_schedule = np.full(n_steps + 1, pressure)
88
+ return t_schedule, p_schedule
89
+
90
+ if isinstance(pressure, Sequence) or (
91
+ isinstance(pressure, np.ndarray) and pressure.ndim == 1
92
+ ):
93
+ p_schedule = _interpolate_quantity(pressure, n_steps)
94
+ elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
95
+ p_schedule = interp1d(
96
+ np.arange(n_steps + 1), pressure, kind="linear"
97
+ )
98
+ assert isinstance(p_schedule, np.ndarray)
99
+ else:
100
+ p_schedule = np.full(n_steps + 1, pressure)
101
+
102
+ return t_schedule, p_schedule
103
+
104
+ def _get_ensemble_defaults(
105
+ ensemble: Literal["nve", "nvt", "npt"],
106
+ dynamics: str | MolecularDynamics,
107
+ t_schedule: np.ndarray,
108
+ p_schedule: np.ndarray,
109
+ ase_md_kwargs: dict | None = None) -> dict:
110
+ """Update ASE MD kwargs"""
111
+ ase_md_kwargs = ase_md_kwargs or {}
112
+
113
+ if ensemble == "nve":
114
+ ase_md_kwargs.pop("temperature", None)
115
+ ase_md_kwargs.pop("temperature_K", None)
116
+ ase_md_kwargs.pop("externalstress", None)
117
+ elif ensemble == "nvt":
118
+ ase_md_kwargs["temperature_K"] = t_schedule[0]
119
+ ase_md_kwargs.pop("externalstress", None)
120
+ elif ensemble == "npt":
121
+ ase_md_kwargs["temperature_K"] = t_schedule[0]
122
+ ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
123
+
124
+ if isinstance(dynamics, str) and dynamics.lower() == "langevin":
125
+ ase_md_kwargs["friction"] = ase_md_kwargs.get(
126
+ "friction",
127
+ 10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
128
+ )
129
+
130
+ return ase_md_kwargs
131
+
132
+
133
 
134
 
135
+ @task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
136
  def md(
137
  atoms: Atoms,
138
+ calculator_name: str | MLIPEnum,
139
  calculator_kwargs: dict | None,
140
  dispersion: str | None = None,
141
  dispersion_kwargs: dict | None = None,
 
161
  # device: str | None = None,
162
  # dtype: str = "float64",
163
  ):
164
+ device = device or str(get_freer_device())
165
 
166
  print(f"Using device: {device}")
167
 
168
  calculator_kwargs = calculator_kwargs or {}
169
 
170
+ if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
171
+ assert issubclass(calculator_name.value, Calculator)
172
+ calc = calculator_name.value(**calculator_kwargs)
173
+ elif isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_:
174
+ calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
175
+ else:
176
+ raise ValueError(f"Invalid calculator: {calculator_name}")
177
 
178
  print(f"Using calculator: {calc}")
179
 
 
276
  with tqdm(total=n_steps) as pbar:
277
 
278
  def _callback(dyn: MolecularDynamics = md_runner) -> None:
279
+ step = last_step + dyn.nsteps
280
+ dyn.atoms.info["datetime"] = datetime.now()
281
+ dyn.atoms.info["step"] = step
282
  if ensemble == "nve":
283
  return
284
+ dyn.set_temperature(temperature_K=t_schedule[step])
285
  if ensemble == "nvt":
286
  return
287
+ dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
288
  pbar.update()
289
 
290
  md_runner.attach(_callback, interval=1)
 
295
 
296
  traj.close()
297
 
298
+ return {"runtime": end_time - start_time, "n_steps": n_steps}
mlip_arena/tasks/utils.py DELETED
@@ -1,161 +0,0 @@
1
- import os, glob
2
- from pathlib import Path
3
- from ase.io import read, write
4
- from ase import units
5
- from ase import Atoms, units
6
- from ase.calculators.calculator import Calculator
7
- from ase.data import chemical_symbols
8
- from ase.md.andersen import Andersen
9
- from ase.md.langevin import Langevin
10
- from ase.md.md import MolecularDynamics
11
- from ase.md.npt import NPT
12
- from ase.md.nptberendsen import NPTBerendsen
13
- from ase.md.nvtberendsen import NVTBerendsen
14
- from ase.md.velocitydistribution import (
15
- MaxwellBoltzmannDistribution,
16
- Stationary,
17
- ZeroRotation,
18
- )
19
- from ase.md.verlet import VelocityVerlet
20
- from dask.distributed import Client
21
- from dask_jobqueue import SLURMCluster
22
- from jobflow import Maker
23
- from prefect import flow, task
24
- from prefect.tasks import task_input_hash
25
- from prefect_dask import DaskTaskRunner
26
- from pymatgen.io.ase import AseAtomsAdaptor
27
- from scipy.interpolate import interp1d
28
- from scipy.linalg import schur
29
-
30
- from mlip_arena.models import MLIPCalculator
31
- from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
32
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
33
- from mp_api.client import MPRester
34
-
35
- from fireworks import LaunchPad
36
- from atomate2.vasp.flows.core import RelaxBandStructureMaker
37
- from atomate2.vasp.flows.mp import MPGGADoubleRelaxStaticMaker
38
- from atomate2.vasp.powerups import add_metadata_to_flow
39
- from atomate2.forcefields.md import (
40
- CHGNetMDMaker,
41
- GAPMDMaker,
42
- M3GNetMDMaker,
43
- MACEMDMaker,
44
- NequipMDMaker,
45
- )
46
- from atomate2.forcefields.utils import MLFF
47
- from pymatgen.io.ase import AseAtomsAdaptor
48
- from pymatgen.transformations.advanced_transformations import CubicSupercellTransformation
49
- from jobflow.managers.fireworks import flow_to_workflow
50
- from jobflow import run_locally, SETTINGS
51
- from tqdm.auto import tqdm
52
-
53
- from datetime import timedelta, datetime
54
- from typing import Literal, Sequence, Tuple
55
-
56
- import numpy as np
57
- import torch
58
- from pymatgen.core.structure import Structure
59
-
60
- from ase.calculators.mixing import SumCalculator
61
- from scipy.interpolate import interp1d
62
-
63
- from ase.io.trajectory import Trajectory
64
-
65
-
66
- _valid_dynamics: dict[str, tuple[str, ...]] = {
67
- "nve": ("velocityverlet",),
68
- "nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
69
- "npt": ("nose-hoover", "berendsen"),
70
- }
71
-
72
- _preset_dynamics: dict = {
73
- "nve_velocityverlet": VelocityVerlet,
74
- "nvt_andersen": Andersen,
75
- "nvt_berendsen": NVTBerendsen,
76
- "nvt_langevin": Langevin,
77
- "nvt_nose-hoover": NPT,
78
- "npt_berendsen": NPTBerendsen,
79
- "npt_nose-hoover": NPT,
80
- }
81
-
82
- def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
83
- """Interpolate temperature / pressure on a schedule."""
84
- n_vals = len(values)
85
- return np.interp(
86
- np.linspace(0, n_vals - 1, n_pts + 1),
87
- np.linspace(0, n_vals - 1, n_vals),
88
- values,
89
- )
90
-
91
- def _get_ensemble_schedule(
92
- ensemble: Literal["nve", "nvt", "npt"] = "nvt",
93
- n_steps: int = 1000,
94
- temperature: float | Sequence | np.ndarray | None = 300.0,
95
- pressure: float | Sequence | np.ndarray | None = None
96
- ) -> Tuple[np.ndarray, np.ndarray]:
97
- if ensemble == "nve":
98
- # Disable thermostat and barostat
99
- temperature = np.nan
100
- pressure = np.nan
101
- t_schedule = np.full(n_steps + 1, temperature)
102
- p_schedule = np.full(n_steps + 1, pressure)
103
- return t_schedule, p_schedule
104
-
105
- if isinstance(temperature, Sequence) or (
106
- isinstance(temperature, np.ndarray) and temperature.ndim == 1
107
- ):
108
- t_schedule = _interpolate_quantity(temperature, n_steps)
109
- # NOTE: In ASE Langevin dynamics, the temperature are normally
110
- # scalars, but in principle one quantity per atom could be specified by giving
111
- # an array. This is not implemented yet here.
112
- else:
113
- t_schedule = np.full(n_steps + 1, temperature)
114
-
115
- if ensemble == "nvt":
116
- pressure = np.nan
117
- p_schedule = np.full(n_steps + 1, pressure)
118
- return t_schedule, p_schedule
119
-
120
- if isinstance(pressure, Sequence) or (
121
- isinstance(pressure, np.ndarray) and pressure.ndim == 1
122
- ):
123
- p_schedule = _interpolate_quantity(pressure, n_steps)
124
- elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
125
- p_schedule = interp1d(
126
- np.arange(n_steps + 1), pressure, kind="linear"
127
- )
128
- assert isinstance(p_schedule, np.ndarray)
129
- else:
130
- p_schedule = np.full(n_steps + 1, pressure)
131
-
132
- return t_schedule, p_schedule
133
-
134
- def _get_ensemble_defaults(
135
- ensemble: Literal["nve", "nvt", "npt"],
136
- dynamics: str | MolecularDynamics,
137
- t_schedule: np.ndarray,
138
- p_schedule: np.ndarray,
139
- ase_md_kwargs: dict | None = None) -> dict:
140
- """Update ASE MD kwargs"""
141
- ase_md_kwargs = ase_md_kwargs or {}
142
-
143
- if ensemble == "nve":
144
- ase_md_kwargs.pop("temperature", None)
145
- ase_md_kwargs.pop("temperature_K", None)
146
- ase_md_kwargs.pop("externalstress", None)
147
- elif ensemble == "nvt":
148
- ase_md_kwargs["temperature_K"] = t_schedule[0]
149
- ase_md_kwargs.pop("externalstress", None)
150
- elif ensemble == "npt":
151
- ase_md_kwargs["temperature_K"] = t_schedule[0]
152
- ase_md_kwargs["externalstress"] = p_schedule[0] * 1e3 * units.bar
153
-
154
- if isinstance(dynamics, str) and dynamics.lower() == "langevin":
155
- ase_md_kwargs["friction"] = ase_md_kwargs.get(
156
- "friction",
157
- 10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
158
- )
159
-
160
- return ase_md_kwargs
161
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/hf_hub.ipynb CHANGED
@@ -2,22 +2,192 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  "metadata": {},
7
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  {
9
  "name": "stderr",
10
  "output_type": "stream",
11
  "text": [
12
- "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
- " from .autonotebook import tqdm as notebook_tqdm\n"
14
  ]
15
  }
16
  ],
17
  "source": [
18
- "import torch\n",
19
- "from huggingface_hub import hf_hub_download\n",
20
- "from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP"
 
 
21
  ]
22
  },
23
  {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 7,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "from huggingface_hub import hf_hub_download\n",
11
+ "from ase.calculators.calculator import Calculator\n",
12
+ "# from mlip_arena.models import MLIP, MLIPCalculator, ModuleMLIP\n",
13
+ "\n",
14
+ "from mlip_arena.models.externals import MACE_MP_Medium\n",
15
+ "\n",
16
+ "from mlip_arena.models.utils import MLIPMap, MLIPEnum"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 9,
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "data": {
26
+ "text/plain": [
27
+ "True"
28
+ ]
29
+ },
30
+ "execution_count": 9,
31
+ "metadata": {},
32
+ "output_type": "execute_result"
33
+ }
34
+ ],
35
+ "source": [
36
+ "issubclass(MLIPEnum[\"MACE-MP(M)\"].value, Calculator)"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 13,
42
  "metadata": {},
43
  "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "True"
48
+ ]
49
+ },
50
+ "execution_count": 13,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ }
54
+ ],
55
+ "source": [
56
+ "isinstance(MLIPEnum[\"MACE-MP(M)\"], MLIPEnum)# in MLIPEnum"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 16,
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "data": {
66
+ "text/plain": [
67
+ "['MACE-MP(M)', 'CHGNet', 'EquiformerV2(OC22)', 'eSCN(OC20)']"
68
+ ]
69
+ },
70
+ "execution_count": 16,
71
+ "metadata": {},
72
+ "output_type": "execute_result"
73
+ }
74
+ ],
75
+ "source": [
76
+ "MLIPEnum._member_names_"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 9,
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "data": {
86
+ "text/plain": [
87
+ "{'MACE-MP(M)': mlip_arena.models.externals.MACE_MP_Medium,\n",
88
+ " 'CHGNet': mlip_arena.models.externals.CHGNet,\n",
89
+ " 'EquiformerV2(OC22)': mlip_arena.models.externals.EquiformerV2,\n",
90
+ " 'eSCN(OC20)': mlip_arena.models.externals.eSCN}"
91
+ ]
92
+ },
93
+ "execution_count": 9,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "MLIPMap"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 8,
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "name": "stdout",
109
+ "output_type": "stream",
110
+ "text": [
111
+ "MLIPEnum.MACE-MP(M)\n",
112
+ "MLIPEnum.CHGNet\n",
113
+ "MLIPEnum.EquiformerV2(OC22)\n",
114
+ "MLIPEnum.eSCN(OC20)\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "for mlip in MLIPEnum:\n",
120
+ " print(mlip)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 4,
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "data": {
130
+ "text/plain": [
131
+ "mlip_arena.models.externals.MACE_MP_Medium"
132
+ ]
133
+ },
134
+ "execution_count": 4,
135
+ "metadata": {},
136
+ "output_type": "execute_result"
137
+ }
138
+ ],
139
+ "source": [
140
+ "MLIPMap['MACE-MP(M)']"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 2,
146
+ "metadata": {},
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Using Materials Project MACE for MACECalculator with /global/homes/c/cyrusyc/.cache/mace/5yyxdm76\n",
153
+ "Selected GPU cuda:0 with 40338.06 MB free memory from 1 GPUs\n",
154
+ "Default dtype float32 does not match model dtype float64, converting models to float32.\n"
155
+ ]
156
+ }
157
+ ],
158
+ "source": [
159
+ "mace_mp = MACE_MP_Medium()"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 3,
165
+ "metadata": {},
166
+ "outputs": [
167
+ {
168
+ "name": "stdout",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "Select GPU cuda:0 with 40316.98 MB free memory from 1 GPUs\n",
172
+ "CHGNet v0.3.0 initialized with 412,525 parameters\n",
173
+ "CHGNet will run on cuda:0\n"
174
+ ]
175
+ },
176
  {
177
  "name": "stderr",
178
  "output_type": "stream",
179
  "text": [
180
+ "WARNING:root:Detected old config, converting to new format. Consider updating to avoid potential incompatibilities.\n",
181
+ "WARNING:root:Skipping scheduler setup. No training set found.\n"
182
  ]
183
  }
184
  ],
185
  "source": [
186
+ "from mlip_arena.models.externals import EquiformerV2, CHGNet\n",
187
+ "\n",
188
+ "chgnet = CHGNet()\n",
189
+ "\n",
190
+ "equiformer_v2 = EquiformerV2()\n"
191
  ]
192
  },
193
  {