from __future__ import annotations

from pathlib import Path

import yaml
from ase import Atoms
from fairchem.core import OCPCalculator
from huggingface_hub import hf_hub_download

with open(Path(__file__).parents[1] / "registry.yaml", encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)


def _registry_checkpoint(name: str, fallback: str) -> str:
    """Return the ``checkpoint`` field of ``REGISTRY[name]``, or *fallback*.

    Used where this module cannot be sure the registry row exists: a missing
    row (or a row without a ``checkpoint`` field) yields *fallback* instead of
    a ``KeyError`` raised while the class body is being evaluated at import.
    """
    entry = REGISTRY.get(name) or {}
    return entry.get("checkpoint", fallback)


def _constrained_forces(results: dict, atoms: Atoms):
    """Return a copy of ``results["forces"]`` with *atoms*' constraints applied.

    This reproduces what ``atoms.get_forces()`` (default arguments) returns
    when ``atoms.calc`` is the calculator that produced *results*. It works
    from *results* directly rather than going through ``atoms.calc``, so it is
    also correct when the calculator is invoked on an ``Atoms`` object that
    has no calculator, or a different one, attached.
    """
    # assumes results["forces"] is a numpy array (ASE copies ndarrays the
    # same way before applying constraints) — NOTE(review): confirm
    forces = results["forces"].copy()
    for constraint in atoms.constraints:
        constraint.adjust_forces(atoms, forces)
    return forces


class eqV2(OCPCalculator):
    """OCPCalculator for the registry's ``"eqV2(OMat)"`` checkpoint.

    The checkpoint file is fetched from the ``fairchem/OMAT24`` Hugging Face
    repository at a pinned revision.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["eqV2(OMat)"]["checkpoint"],
        cache_dir=None,
        cpu=False,  # TODO: cannot assign device
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an eqV2 calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["eqV2(OMat)"]["checkpoint"]
            File name of the eqV2 checkpoint inside the ``fairchem/OMAT24``
            Hugging Face repository (e.g. ``"eqV2_86M_omat_mp_salex.pt"``).
        cache_dir : str or path-like, optional
            Cache directory passed to ``hf_hub_download`` for the downloaded
            checkpoint. ``None`` uses the huggingface_hub default cache.
        cpu : bool, default=False
            Forwarded to ``OCPCalculator``; whether to run the model on CPU
            instead of GPU.
        seed : int, default=0
            Random seed forwarded to ``OCPCalculator``.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        # https://huggingface.co/fairchem/OMAT24/resolve/main/eqV2_86M_omat_mp_salex.pt
        # The revision is pinned so the same checkpoint bytes are always loaded.
        checkpoint_path = hf_hub_download(
            "fairchem/OMAT24",
            filename=checkpoint,
            revision="bf92f9671cb9d5b5c77ecb4aa8b317ff10b882ce",
            cache_dir=cache_dir,
        )
        super().__init__(
            checkpoint_path=checkpoint_path,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )


class EquiformerV2(OCPCalculator):
    """OCPCalculator for the registry's ``"EquiformerV2(OC22)"`` model.

    After each calculation, ``results`` additionally carries a ``"force"``
    entry holding the constraint-adjusted forces.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Create the calculator.

        ``checkpoint`` is forwarded to ``OCPCalculator`` as ``model_name``.
        ``local_cache``, ``cpu``, ``seed`` and any extra ``**kwargs`` are
        forwarded unchanged.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the OCP calculation, then add ``results["force"]``.

        ``results["force"]`` is a copy of ``results["forces"]`` with the
        constraints of *atoms* applied, i.e. what ``atoms.get_forces()`` would
        return. It is only added when the model produced ``"forces"``.
        """
        super().calculate(atoms, properties, system_changes)
        # Built from self.results rather than atoms.get_forces(): the latter
        # goes through atoms.calc, which may be unset or another calculator,
        # and would re-enter calculate() if "forces" were missing.
        if "forces" in self.results:
            self.results["force"] = _constrained_forces(self.results, atoms)


class EquiformerV2OC20(OCPCalculator):
    """OCPCalculator for the registry's ``"EquiformerV2(OC20)"`` model."""

    def __init__(
        self,
        # This is the OC20 variant, so it reads the OC20 registry entry (not
        # the OC22 one); the literal is used only if that entry is absent.
        checkpoint=_registry_checkpoint(
            "EquiformerV2(OC20)", "EquiformerV2-31M-S2EF-OC20-All+MD"
        ),
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Create the calculator.

        ``checkpoint`` is forwarded to ``OCPCalculator`` as ``model_name``.
        ``local_cache``, ``cpu``, ``seed`` and any extra ``**kwargs`` are
        forwarded unchanged.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )


class eSCN(OCPCalculator):
    """OCPCalculator for the eSCN OC20 model.

    After each calculation, ``results`` additionally carries a ``"force"``
    entry holding the constraint-adjusted forces.
    """

    def __init__(
        self,
        # Read from the registry; the literal (the previous hard-coded
        # default) is used only if the registry has no "eSCN(OC20)" entry.
        checkpoint=_registry_checkpoint(
            "eSCN(OC20)", "eSCN-L6-M3-Lay20-S2EF-OC20-All+MD"
        ),
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Create the calculator.

        ``checkpoint`` is forwarded to ``OCPCalculator`` as ``model_name``.
        ``local_cache``, ``cpu``, ``seed`` and any extra ``**kwargs`` are
        forwarded unchanged.
        """
        super().__init__(
            model_name=checkpoint,
            local_cache=local_cache,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the OCP calculation, then add ``results["force"]``.

        ``results["force"]`` is a copy of ``results["forces"]`` with the
        constraints of *atoms* applied, i.e. what ``atoms.get_forces()`` would
        return. It is only added when the model produced ``"forces"``.
        """
        super().calculate(atoms, properties, system_changes)
        # See EquiformerV2.calculate: avoid routing through atoms.calc.
        if "forces" in self.results:
            self.results["force"] = _constrained_forces(self.results, atoms)