cyrusyc committed on
Commit ee784cf
1 Parent(s): 2d8bda8

updates before modification

.gitignore CHANGED
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+tests/
 
 # C extensions
 *.so
mlip_arena/models/__init__.py CHANGED
@@ -12,11 +12,6 @@ from torch_geometric.data import Data
 with open(os.path.join(os.path.dirname(__file__), "registry.yaml")) as f:
     REGISTRY = yaml.load(f, Loader=yaml.FullLoader)
 
-# class MLIPEnum(enum.Enum):
-# for model, metadata in REGISTRY.items():
-# model_class = getattr(importlib.import_module(model["module"]), model)
-# self.setattr(model, model_class)
-
 
 class MLIP(
     nn.Module,
@@ -30,7 +25,7 @@ class MLIP(
 class ModuleMLIP(MLIP):
     def __init__(self, model: nn.Module, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.register_module("model", model)
+        self.add_module("model", model)
 
     def forward(self, x):
         print("Forwarding...")
@@ -41,15 +36,12 @@ class ModuleMLIP(MLIP):
 
 class MLIPCalculator(Calculator):
     name: str
-    device: torch.device
-    model: MLIP
+    # device: torch.device
+    # model: MLIP
     implemented_properties: list[str] = ["energy", "forces", "stress"]
 
     def __init__(
         self,
-        # PyTorch
-        model_path: str | Path,
-        device: torch.device | None = None,
         # ASE Calculator
         restart=None,
         atoms=None,
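Note: `nn.Module.add_module` does the same registration as the `register_module` call it replaces here; the wrapped network becomes a named child, so it shows up in `state_dict()` and can be fetched with `get_submodule("model")`. A minimal standalone sketch using plain PyTorch (the `Wrapper` class below is illustrative only, not part of this repository):

import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        # Register the submodule under the name "model" so it is tracked by
        # parameters(), state_dict(), and get_submodule("model").
        self.add_module("model", model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

wrapped = Wrapper(nn.Linear(4, 2))
print(wrapped.get_submodule("model"))      # Linear(in_features=4, out_features=2, bias=True)
print(list(wrapped.state_dict().keys()))   # ['model.weight', 'model.bias']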
mlip_arena/models/mace.py CHANGED
@@ -1,3 +1,6 @@
+from typing import Optional, Tuple
+
+import numpy as np
 import torch
 from ase import Atoms
 from ase.calculators.calculator import all_changes
@@ -16,48 +19,33 @@ class MACE_MP_Medium(MLIPCalculator):
         directory=".",
         **kwargs,
     ):
-        # Download the pytorch model from huggingface to local and load it
-        # NOTE: this is not the ideal way to load the model, but it is the simplest
-        # way to do it for now. Ideally, if the model is the subclass of PyTorchModelHubMixin,
-        # we should be able to load it directly from the hub or local using MLIP class.
+        super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)
+
+        self.name: str = self.__class__.__name__
+
         fpath = hf_hub_download(
             repo_id="cyrusyc/mace-universal",
             subfolder="pretrained",
             filename="2023-12-12-mace-128-L1_epoch-199.model",
-            revision=None,  # TODO: Add revision
+            revision="main",
         )
-        # module = ModuleMLIP(torch.load(fpath, map_location="cpu"))
-        print(torch.load(fpath, map_location="cpu"))
-        repo_id = f"atomind/{self.__class__.__name__}".replace("_", "-")
-        # module.save_pretrained(
-        # save_directory=self.__class__.__name__,
-        # repo_id=repo_id,
-        # push_to_hub=True,
-        # )
 
-        super().__init__(
-            model_path=repo_id,
-            device=device,
-            restart=restart,
-            atoms=atoms,
-            directory=directory,
-            **kwargs,
+        self.device = device or torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
         )
 
-        # self.name: str = self.__class__.__name__
-        # self.device = device or torch.device(
-        # "cuda" if torch.cuda.is_available() else "cpu"
-        # )
-        # self.model: MLIP = ModuleMLIP.from_pretrained(repo_id, map_location=self.device)
-        # self.implemented_properties = ["energy", "forces", "stress"]
+        self.model = torch.load(fpath, map_location=self.device)
+
+        self.implemented_properties = ["energy", "forces", "stress"]
+
+        # repo_id = f"atomind/{self.__class__.__name__}".lower().replace("_", "-")
 
-        self.display = "MACE-MP-0 (medium)"
-        self.version = "1.0.0"
-        self.implemented_properties = [
-            "energy",
-            "forces",
-            "stress",
-        ]
+        # model = ModuleMLIP(model=model)
+        # model.save_pretrained(
+        #     self.__class__.__name__.lower().replace("_", "-"),
+        #     repo_id=repo_id,
+        #     push_to_hub=True,
+        # )
 
     def calculate(
         self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
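With this change the calculator resolves its own device and loads the downloaded MACE checkpoint directly with `torch.load`. A hedged usage sketch, assuming the constructor shown above and a working `mlip_arena` install:

import torch
from ase.build import bulk

from mlip_arena.models.mace import MACE_MP_Medium

# Fall back to CPU when CUDA is unavailable (mirrors the constructor's default).
calc = MACE_MP_Medium(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = calc

print(atoms.get_potential_energy())  # eV
print(atoms.get_forces().shape)      # (natoms, 3), eV/Å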
mlip_arena/models/utils.py CHANGED
@@ -1,15 +1,91 @@
+"""Utility functions for MLIP models."""
+
 import importlib
-import os
 from enum import Enum
+from typing import Any
+
+import torch
+from ase.calculators.calculator import Calculator
+from ase.calculators.mixing import SumCalculator
+from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
 
 from mlip_arena.models import REGISTRY
 
-MLIPEnum = Enum(
-    "MLIPEnum",
-    {
-        model: getattr(
-            importlib.import_module(f"{__package__}.{metadata['module']}"), model
-        )
-        for model, metadata in REGISTRY.items()
-    },
-)
+MLIPMap = {
+    model: getattr(
+        importlib.import_module(f"{__package__}.{metadata['module']}"), model
+    )
+    for model, metadata in REGISTRY.items()
+}
+
+
+class EXTMLIPEnum(Enum):
+    """
+    Enumeration class for EXTMLIP models.
+
+    Attributes:
+        M3GNet (str): M3GNet model.
+        CHGNet (str): CHGNet model.
+        MACE (str): MACE model.
+    """
+
+    M3GNet = "M3GNet"
+    CHGNet = "CHGNet"
+    MACE = "MACE"
+
+
+def get_freer_device() -> torch.device:
+    """Get the GPU with the most free memory.
+
+    Returns:
+        torch.device: The selected GPU device.
+
+    Raises:
+        ValueError: If no GPU is available.
+    """
+    device_count = torch.cuda.device_count()
+    if device_count == 0:
+        print("No GPU available. Using CPU.")
+        return torch.device("cpu")
+
+    mem_free = [
+        torch.cuda.get_device_properties(i).total_memory
+        - torch.cuda.memory_allocated(i)
+        for i in range(device_count)
+    ]
+
+    free_gpu_index = mem_free.index(max(mem_free))
+
+    print(
+        f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
+    )
+
+    return torch.device(f"cuda:{free_gpu_index}")
+
+
+def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
+    """Construct an ASE calculator from an external third-party MLIP packages"""
+
+    calculator = None
+    device = get_freer_device()
+
+    if name == EXTMLIPEnum.MACE:
+        from mace.calculators import mace_mp
+
+        calculator = mace_mp(device=str(device), **kwargs)
+
+    elif name == EXTMLIPEnum.CHGNet:
+        from chgnet.model.dynamics import CHGNetCalculator
+
+        calculator = CHGNetCalculator(use_device=str(device), **kwargs)
+
+    elif name == EXTMLIPEnum.M3GNet:
+        import matgl
+        from matgl.ext.ase import PESCalculator
+
+        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
+        calculator = PESCalculator(potential, **kwargs)
+
+    calculator.__setattr__("name", name.value)
+
+    return calculator
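A short usage sketch of the helpers added above, assuming the corresponding external packages (e.g. `mace-torch`, `chgnet`, or `matgl`) are installed alongside `mlip_arena`:

from mlip_arena.models.utils import EXTMLIPEnum, external_ase_calculator, get_freer_device

device = get_freer_device()  # GPU with the most free memory, or CPU if none is visible
print(device)

# Wraps mace.calculators.mace_mp(device=...) and tags the calculator with name="MACE".
calc = external_ase_calculator(EXTMLIPEnum.MACE)
print(calc.name)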
mlip_arena/tasks/diatomics.py DELETED
@@ -1,123 +0,0 @@
-import covalent as ct
-import numpy as np
-import pandas as pd
-import torch
-from ase import Atoms
-from ase.calculators.calculator import Calculator
-from ase.data import chemical_symbols
-from matplotlib import pyplot as plt
-
-from mlip_arena.models import MLIPCalculator
-
-device = torch.device("cuda")
-
-
-@ct.electron
-def calculate_single_diatomic(
-    calculator: MLIPCalculator | Calculator,
-    atom1: str,
-    atom2: str,
-    rmin: float = 0.1,
-    rmax: float = 6.5,
-    npts: int = int(1e3),
-):
-    a = 2 * rmax
-
-    rs = np.linspace(rmin, rmax, npts)
-    e = np.zeros_like(rs)
-    f = np.zeros_like(rs)
-
-    da = atom1 + atom2
-
-    for i, r in enumerate(rs):
-
-        positions = [
-            [0, 0, 0],
-            [r, 0, 0],
-        ]
-
-        # Create the unit cell with two atoms
-        atoms = Atoms(da, positions=positions, cell=[a, a, a])
-
-        atoms.calc = calculator
-
-        e[i] = atoms.get_potential_energy()
-        f[i] = np.inner(np.array([1, 0, 0]), atoms.get_forces()[1])
-
-    return rs, e, f, da
-
-
-@ct.lattice
-def calculate_homonuclear_diatomics(calculator: MLIPCalculator | Calculator):
-
-    chemical_symbols.remove("X")
-
-    results = {}
-
-    for atom in chemical_symbols:
-        rs, e, f, da = calculate_single_diatomic(calculator, atom, atom)
-        results[da] = {"r": rs, "E": e, "F": f}
-
-    return results
-
-
-# with plt.style.context("default"):
-
-# SMALL_SIZE = 6
-# MEDIUM_SIZE = 8
-# LARGE_SIZE = 10
-
-# LINE_WIDTH = 1
-
-# plt.rcParams.update(
-# {
-# "pgf.texsystem": "pdflatex",
-# "font.family": "sans-serif",
-# "text.usetex": True,
-# "pgf.rcfonts": True,
-# "figure.constrained_layout.use": True,
-# "axes.labelsize": MEDIUM_SIZE,
-# "axes.titlesize": MEDIUM_SIZE,
-# "legend.frameon": False,
-# "legend.fontsize": MEDIUM_SIZE,
-# "legend.loc": "best",
-# "lines.linewidth": LINE_WIDTH,
-# "xtick.labelsize": SMALL_SIZE,
-# "ytick.labelsize": SMALL_SIZE,
-# }
-# )
-
-# fig, ax = plt.subplots(layout="constrained", figsize=(3, 2), dpi=300)
-
-# color = "tab:red"
-# ax.plot(rs, e, color=color, zorder=1)
-
-# ax.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
-
-# ylo, yhi = ax.get_ylim()
-# ax.set(xlabel=r"r [$\AA]$", ylim=(max(-7, ylo), min(5, yhi)))
-# ax.set_ylabel(ylabel="E [eV]", color=color)
-# ax.tick_params(axis="y", labelcolor=color)
-# ax.text(0.8, 0.85, da, fontsize=LARGE_SIZE, transform=ax.transAxes)
-
-# color = "tab:blue"
-
-# at = ax.twinx()
-# at.plot(rs, f, color=color, zorder=0, lw=0.5 * LINE_WIDTH)
-
-# at.axhline(ls="--", color=color, alpha=0.5, lw=0.5 * LINE_WIDTH)
-
-# ylo, yhi = at.get_ylim()
-# at.set(
-# xlabel=r"r [$\AA]$",
-# ylim=(max(-20, ylo), min(20, yhi)),
-# )
-# at.set_ylabel(ylabel="F [eV/$\AA$]", color=color)
-# at.tick_params(axis="y", labelcolor=color)
-
-# plt.show()
-
-
-if __name__ == "__main__":
-
-    local = ct.executor.LocalExecutor()
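The deleted task scanned the bond length of each homonuclear dimer and recorded the energy and axial force. Stripped of the Covalent decorators, the core loop looks roughly like the sketch below (illustrative only, not part of this commit):

import numpy as np
from ase import Atoms
from ase.calculators.calculator import Calculator


def diatomic_curve(calculator: Calculator, symbol: str, rmin=0.1, rmax=6.5, npts=1000):
    """Scan energy and axial force of a homonuclear dimer in a large cubic box."""
    rs = np.linspace(rmin, rmax, npts)
    energies = np.zeros_like(rs)
    forces = np.zeros_like(rs)
    box = 2 * rmax  # big enough that periodic images do not interact

    for i, r in enumerate(rs):
        atoms = Atoms([symbol, symbol], positions=[[0, 0, 0], [r, 0, 0]], cell=[box, box, box])
        atoms.calc = calculator
        energies[i] = atoms.get_potential_energy()
        forces[i] = atoms.get_forces()[1, 0]  # x-component of the force on the second atom

    return rs, energies, forces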
tests/hf_hub.ipynb CHANGED
@@ -2,9 +2,18 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "from huggingface_hub import hf_hub_download\n",
@@ -13,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +39,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -39,23 +48,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "model.safetensors: 100%|██████████| 44.2M/44.2M [00:02<00:00, 20.6MB/s]\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
-       "CommitInfo(commit_url='https://huggingface.co/atomind/mace-mp-medium/commit/ef94f6bd9c7167bb28d594d5a3e7a5fbfbda2acb', commit_message='Push model using huggingface_hub.', commit_description='', oid='ef94f6bd9c7167bb28d594d5a3e7a5fbfbda2acb', pr_url=None, pr_revision=None, pr_num=None)"
+       "CommitInfo(commit_url='https://huggingface.co/atomind/mace-mp-medium/commit/eb12c5387b9e655d83a4e2e10c0f0779c3745227', commit_message='Push model using huggingface_hub.', commit_description='', oid='eb12c5387b9e655d83a4e2e10c0f0779c3745227', pr_url=None, pr_revision=None, pr_num=None)"
       ]
      },
-     "execution_count": 12,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
    }
@@ -68,6 +70,214 @@
     ")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      " from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "from mlip_arena.models.mace import MACE_MP_Medium\n",
+    "import torch\n",
+    "\n",
+    "calc = MACE_MP_Medium(device=torch.device(\"cuda\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "ScaleShiftMACE(\n",
+       "  (node_embedding): LinearNodeEmbeddingBlock(\n",
+       "    (linear): Linear(89x0e -> 128x0e | 11392 weights)\n",
+       "  )\n",
+       "  (radial_embedding): RadialEmbeddingBlock(\n",
+       "    (bessel_fn): BesselBasis(r_max=6.0, num_basis=10, trainable=False)\n",
+       "    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=6.0)\n",
+       "  )\n",
+       "  (spherical_harmonics): SphericalHarmonics()\n",
+       "  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-3.6672, -1.3321, -3.4821, -4.7367, -7.7249, -8.4056, -7.3601, -7.2846, -4.8965, 0.0000, -2.7594, -2.8140, -4.8469, -7.6948, -6.9633, -4.6726, -2.8117, -0.0626, -2.6176, -5.3905, -7.8858, -10.2684, -8.6651, -9.2331, -8.3050, -7.0490, -5.5774, -5.1727, -3.2521, -1.2902, -3.5271, -4.7085, -3.9765, -3.8862, -2.5185, 6.7669, -2.5635, -4.9380, -10.1498, -11.8469, -12.1389, -8.7917, -8.7869, -7.7809, -6.8500, -4.8910, -2.0634, -0.6396, -2.7887, -3.8186, -3.5871, -2.8804, -1.6356, 9.8467, -2.7653, -4.9910, -8.9337, -8.7356, -8.0190, -8.2515, -7.5917, -8.1697, -13.5927, -18.5175, -7.6474, -8.1230, -7.6078, -6.8503, -7.8269, -3.5848, -7.4554, -12.7963, -14.1081, -9.3549, -11.3875, -9.6219, -7.3244, -5.3047, -2.3801, 0.2495, -2.3240, -3.7300, -3.4388, -5.0629, -11.0246, -12.2656, -13.8556, -14.9331, -15.2828])\n",
+       "  (interactions): ModuleList(\n",
+       "    (0): RealAgnosticResidualInteractionBlock(\n",
+       "      (linear_up): Linear(128x0e -> 128x0e | 16384 weights)\n",
+       "      (conv_tp): TensorProduct(128x0e x 1x0e+1x1o+1x2e+1x3o -> 128x0e+128x1o+128x2e+128x3o | 512 paths | 512 weights)\n",
+       "      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 512]\n",
+       "      (linear): Linear(128x0e+128x1o+128x2e+128x3o -> 128x0e+128x1o+128x2e+128x3o | 65536 weights)\n",
+       "      (skip_tp): FullyConnectedTensorProduct(128x0e x 89x0e -> 128x0e+128x1o | 1458176 paths | 1458176 weights)\n",
+       "      (reshape): reshape_irreps()\n",
+       "    )\n",
+       "    (1): RealAgnosticResidualInteractionBlock(\n",
+       "      (linear_up): Linear(128x0e+128x1o -> 128x0e+128x1o | 32768 weights)\n",
+       "      (conv_tp): TensorProduct(128x0e+128x1o x 1x0e+1x1o+1x2e+1x3o -> 256x0e+384x1o+384x2e+256x3o | 1280 paths | 1280 weights)\n",
+       "      (conv_tp_weights): FullyConnectedNet[10, 64, 64, 64, 1280]\n",
+       "      (linear): Linear(256x0e+384x1o+384x2e+256x3o -> 128x0e+128x1o+128x2e+128x3o | 163840 weights)\n",
+       "      (skip_tp): FullyConnectedTensorProduct(128x0e+128x1o x 89x0e -> 128x0e | 1458176 paths | 1458176 weights)\n",
+       "      (reshape): reshape_irreps()\n",
+       "    )\n",
+       "  )\n",
+       "  (products): ModuleList(\n",
+       "    (0): EquivariantProductBasisBlock(\n",
+       "      (symmetric_contractions): SymmetricContraction(\n",
+       "        (contractions): ModuleList(\n",
+       "          (0): Contraction(\n",
+       "            (contractions_weighting): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (contractions_features): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (weights): ParameterList(\n",
+       "              (0): Parameter containing: [torch.float64 of size 89x4x128 (cuda:0)]\n",
+       "              (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
+       "            )\n",
+       "            (graph_opt_main): GraphModule()\n",
+       "          )\n",
+       "          (1): Contraction(\n",
+       "            (contractions_weighting): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (contractions_features): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (weights): ParameterList(\n",
+       "              (0): Parameter containing: [torch.float64 of size 89x6x128 (cuda:0)]\n",
+       "              (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
+       "            )\n",
+       "            (graph_opt_main): GraphModule()\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (linear): Linear(128x0e+128x1o -> 128x0e+128x1o | 32768 weights)\n",
+       "    )\n",
+       "    (1): EquivariantProductBasisBlock(\n",
+       "      (symmetric_contractions): SymmetricContraction(\n",
+       "        (contractions): ModuleList(\n",
+       "          (0): Contraction(\n",
+       "            (contractions_weighting): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (contractions_features): ModuleList(\n",
+       "              (0-1): 2 x GraphModule()\n",
+       "            )\n",
+       "            (weights): ParameterList(\n",
+       "              (0): Parameter containing: [torch.float64 of size 89x4x128 (cuda:0)]\n",
+       "              (1): Parameter containing: [torch.float64 of size 89x1x128 (cuda:0)]\n",
+       "            )\n",
+       "            (graph_opt_main): GraphModule()\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (linear): Linear(128x0e -> 128x0e | 16384 weights)\n",
+       "    )\n",
+       "  )\n",
+       "  (readouts): ModuleList(\n",
+       "    (0): LinearReadoutBlock(\n",
+       "      (linear): Linear(128x0e+128x1o -> 1x0e | 128 weights)\n",
+       "    )\n",
+       "    (1): NonLinearReadoutBlock(\n",
+       "      (linear_1): Linear(128x0e -> 16x0e | 2048 weights)\n",
+       "      (non_linearity): Activation [x] (16x0e -> 16x0e)\n",
+       "      (linear_2): Linear(16x0e -> 1x0e | 16 weights)\n",
+       "    )\n",
+       "  )\n",
+       "  (scale_shift): ScaleShiftBlock(scale=0.804154, shift=0.164097)\n",
+       ")"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "calc.model\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from mlip_arena.models import MLIP\n",
+    "\n",
+    "model = MLIP.from_pretrained(\"atomind/mace-mp-medium\", map_location=\"cuda\", revision=\"main\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<generator object Module.modules at 0x7ff33915f920>"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.modules()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "AttributeError",
+     "evalue": "MLIP has no attribute `model`",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_submodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/torch/nn/modules/module.py:681\u001b[0m, in \u001b[0;36mModule.get_submodule\u001b[0;34m(self, target)\u001b[0m\n\u001b[1;32m 678\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m atoms:\n\u001b[1;32m 680\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(mod, item):\n\u001b[0;32m--> 681\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(mod\u001b[38;5;241m.\u001b[39m_get_name() \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m has no \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 682\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattribute `\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m item \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 684\u001b[0m mod \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(mod, item)\n\u001b[1;32m 686\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule):\n",
+      "\u001b[0;31mAttributeError\u001b[0m: MLIP has no attribute `model`"
+     ]
+    }
+   ],
+   "source": [
+    "model.get_submodule(\"model\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for name, param in model.named_parameters():\n",
+    "    print(name, param.data)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(module)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
tests/oxygen_diatomics.ipynb CHANGED
The diff for this file is too large to render. See raw diff
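As a closing note on the failing `model.get_submodule("model")` cell in tests/hf_hub.ipynb above: the lookup only succeeds when the raw network is registered as a named child before the checkpoint is saved and pushed. A minimal, self-contained sketch of that round trip with `PyTorchModelHubMixin` (the `TinyMLIP` class is hypothetical and only illustrates the pattern):

import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class TinyMLIP(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # Registering under "model" makes get_submodule("model") work after reload.
        self.add_module("model", nn.Linear(4, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

m = TinyMLIP()
m.save_pretrained("tiny-mlip")              # saves weights (and config) to a local directory
m2 = TinyMLIP.from_pretrained("tiny-mlip")  # or a hub repo id, e.g. "user/tiny-mlip"
print(m2.get_submodule("model"))            # Linear(in_features=4, out_features=1, bias=True)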