File size: 2,698 Bytes
b3722a8
49d0cfc
b3722a8
d390139
b3722a8
e473a42
 
49d0cfc
d390139
 
 
b3722a8
 
d390139
49d0cfc
 
 
 
 
 
 
 
 
 
2d8bda8
 
 
ee784cf
2d8bda8
 
 
 
 
 
 
 
49d0cfc
2d8bda8
ee784cf
 
2d8bda8
 
49d0cfc
 
2d8bda8
 
 
 
 
49d0cfc
2d8bda8
 
 
 
 
 
 
d390139
49d0cfc
 
 
d390139
 
 
 
 
 
 
49d0cfc
d390139
49d0cfc
 
 
d390139
49d0cfc
d390139
49d0cfc
d390139
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
from pathlib import Path

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch_geometric.data import Data

# Model registry read from ``registry.yaml`` located next to this module.
# Decoded explicitly as UTF-8: YAML files are conventionally UTF-8, while
# open()'s default encoding depends on the platform locale (e.g. cp1252 on
# Windows), which would mis-decode any non-ASCII entry.
# NOTE(review): FullLoader is kept for compatibility with the existing file;
# presumably the file is bundled/trusted — if it can ever come from an
# untrusted source, switch to yaml.safe_load.
with Path(__file__).with_name("registry.yaml").open(encoding="utf-8") as f:
    REGISTRY = yaml.load(f, Loader=yaml.FullLoader)


class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """Common base class for machine-learning interatomic potentials.

    Combines ``torch.nn.Module`` with Hugging Face's ``PyTorchModelHubMixin``;
    the ``tags`` keyword in the class statement is handed to the mixin when
    the class is created. Subclasses are presumably loaded through the mixin's
    ``from_pretrained`` (see the calculator usage) — confirm against callers.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pass every argument through, unchanged, to the next class in the MRO.
        super(MLIP, self).__init__(*args, **kwargs)


class ModuleMLIP(MLIP):
    """MLIP that delegates its forward pass to an arbitrary ``nn.Module``.

    Args:
        model: Module to wrap. It is registered as the child module ``model``
            so its parameters and buffers belong to this MLIP.
        *args, **kwargs: Forwarded unchanged to ``MLIP.__init__``.
    """

    def __init__(self, model: nn.Module, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Register under a fixed name so state_dict keys are "model.*".
        self.add_module("model", model)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged.

        This is called once per energy/force evaluation, so it deliberately
        produces no console output.
        """
        return self.model(x)


class MLIPCalculator(Calculator):
    """ASE calculator base class for machine-learning interatomic potentials.

    Subclasses implement :meth:`forward`, which turns an ``Atoms`` object into
    model input, runs the model and returns a dict of tensors keyed by property
    name. :meth:`calculate` converts those tensors into the plain float / NumPy
    values that ASE stores in ``self.results``.
    """

    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        **kwargs,
    ):
        """Forward all arguments to ``ase.calculators.calculator.Calculator``."""
        super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)

    def calculate(
        self,
        atoms: "Atoms | None" = None,
        properties: "list[str] | None" = None,
        system_changes: list = all_changes,
    ):
        """Calculate energies, forces and stress for the given Atoms object.

        Args:
            atoms: Structure to evaluate. ``None`` reuses the atoms attached to
                the calculator, as ASE's ``Calculator.calculate`` allows.
            properties: Property names requested by ASE; defaults to
                ``["energy"]``. Every implemented property present in the
                model output is stored, not only the requested ones, so a
                later request for another property on the same structure does
                not re-run the model.
            system_changes: Forwarded to ``Calculator.calculate``.

        Raises:
            ValueError: If no atoms were given and none are attached.
            KeyError: If a requested property is missing from the output of
                :meth:`forward`.
        """
        if properties is None:
            properties = ["energy"]
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            atoms = self.atoms
        if atoms is None:
            raise ValueError("No Atoms object to calculate: pass atoms or attach them first")

        output = self.forward(atoms)

        converters = {
            "energy": lambda t: t.squeeze().item(),
            # reshape instead of squeeze: a single-atom (1, 3) tensor must stay
            # (natoms, 3) as ASE expects, and a leading batch dim is dropped.
            "forces": lambda t: t.cpu().numpy().reshape(-1, 3),
            "stress": lambda t: t.squeeze().cpu().numpy(),
        }

        self.results = {}
        for prop, convert in converters.items():
            requested = prop in properties
            available = prop in output and prop in self.implemented_properties
            if requested or available:
                # Requested-but-missing properties raise KeyError here.
                self.results[prop] = convert(output[prop].detach())

    def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
        """Implement data conversion, graph creation, and model forward pass

        Example implementation:
        1. Use `ase.neighborlist.NeighborList` to get neighbor list
        2. Create `torch_geometric.data.Data` object and copy the data
        3. Pass the `Data` object to the model and return the output

        Returns:
            Mapping from property name ("energy", "forces", "stress") to the
            corresponding tensor.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")