File size: 3,051 Bytes
3b3aaa9
 
52c1bfb
 
49d0cfc
b3722a8
d390139
b3722a8
e473a42
 
49d0cfc
d390139
d72faca
 
d390139
3b3aaa9
056d8d3
49d0cfc
52c1bfb
 
 
 
 
 
 
 
 
 
 
0d1ce35
49d0cfc
 
 
 
 
7cbf186
 
 
2d8bda8
 
7cbf186
2d8bda8
0d1ce35
7cbf186
2d8bda8
 
 
49d0cfc
 
7cbf186
2d8bda8
 
 
 
7cbf186
49d0cfc
7cbf186
0d1ce35
 
 
7cbf186
2d8bda8
 
 
 
 
 
d390139
49d0cfc
0d1ce35
 
 
 
49d0cfc
d390139
 
 
 
 
 
 
49d0cfc
d390139
49d0cfc
 
 
d390139
49d0cfc
d390139
49d0cfc
d390139
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import annotations

import importlib
import warnings
from enum import Enum
from pathlib import Path

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

# from torch_geometric.data import Data

# Load the model registry stored next to this module. Each top-level key is a
# model name; each value must provide the "module", "family" and "class"
# entries used below to locate that model's implementation class.
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)

# Model name -> implementation class, for every model whose module imported.
MLIPMap = {}

for model, metadata in REGISTRY.items():
    try:
        module = importlib.import_module(f"{__package__}.{metadata['module']}.{metadata['family']}")
    except ModuleNotFoundError as e:
        # A model's backend package may not be installed: skip that model,
        # but report why through the warnings machinery (filterable, goes to
        # stderr) instead of printing to stdout at import time.
        warnings.warn(f"Skipping MLIP model {model!r}: {e}", stacklevel=1)
        continue
    MLIPMap[model] = getattr(module, metadata["class"])

# Enum over the importable models; each member's value is the model class.
# Enum treats members with equal values as aliases, so two registry names
# pointing at the same class collapse into a single canonical member.
MLIPEnum = Enum("MLIPEnum", MLIPMap)

class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """``nn.Module`` wrapper whose ``forward`` delegates to a wrapped model.

    ``PyTorchModelHubMixin`` is mixed in to give the wrapper the Hugging Face
    Hub helpers; the ``tags`` class keyword is consumed by that mixin.

    NOTE(review): the only constructor argument is a module object, which is
    presumably not JSON-serializable into the Hub config — confirm how
    ``from_pretrained`` is expected to rebuild ``model``.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so the
        # wrapped model's weights belong to this wrapper's state.
        self.model = model

    def forward(self, x):
        """Feed ``x`` to the wrapped model and hand back its output as-is."""
        wrapped = self.model
        return wrapped(x)


class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by an MLIP model.

    Subclasses implement :meth:`forward`, which turns an ``Atoms`` object into
    a dict of tensors keyed by property name. :meth:`calculate` then converts
    the requested entries into Python / NumPy values in ``self.results``.
    """

    name: str  # presumably set by subclasses — not assigned in this class
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Initialise both the ``nn.Module`` part and the ASE ``Calculator`` part.

        Args:
            model: Model wrapped by :class:`MLIP` and stored as ``self.model``.
            restart: Forwarded to ``Calculator.__init__``.
            atoms: Forwarded to ``Calculator.__init__``.
            directory: Forwarded to ``Calculator.__init__``.
            calculator_kwargs: Extra keyword arguments for
                ``Calculator.__init__``; ``None`` (the default) means none.
        """
        # The two bases are initialised explicitly, one after the other.
        # NOTE(review): ASE's Calculator.__init__ assigns ``self.parameters``,
        # which may shadow ``nn.Module.parameters()`` on this instance — verify,
        # and reach the weights through ``self.model.parameters()``.
        MLIP.__init__(self, model=model)
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            **(calculator_kwargs or {}),
        )

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate the requested properties for ``atoms``.

        Args:
            atoms: Structure to evaluate.
            properties: Property names requested by ASE; only these are
                copied from the model output into ``self.results``.
            system_changes: Passed through to ``Calculator.calculate``.

        Raises:
            KeyError: If ``forward`` does not return a requested property.
        """
        super().calculate(atoms, properties, system_changes)

        output = self.forward(atoms)

        # Build the results locally and publish them in one assignment, so a
        # missing key cannot leave ``self.results`` half-populated.
        results = {}
        if "energy" in properties:
            results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # ASE expects forces as an (n_atoms, 3) array. reshape (rather than
            # squeeze) keeps the atom axis for single-atom systems and still
            # drops a leading batch axis of size 1.
            results["forces"] = output["forces"].detach().cpu().reshape(-1, 3).numpy()
        if "stress" in properties:
            results["stress"] = output["stress"].squeeze().detach().cpu().numpy()
        self.results = results

    def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
        """Convert ``x`` into model input, run the model and return its outputs.

        Must be implemented by subclasses. :meth:`calculate` reads the
        returned dict under the keys ``"energy"``, ``"forces"`` and
        ``"stress"`` (only those that were requested).

        Example implementation:
        1. Use `ase.neighborlist.NeighborList` to get neighbor list
        2. Create `torch_geometric.data.Data` object and copy the data
        3. Pass the `Data` object to the model and return the output

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError