Spaces:
Running
Running
from __future__ import annotations | |
from pathlib import Path | |
import yaml | |
from ase import Atoms | |
from fairchem.core import OCPCalculator | |
from huggingface_hub import hf_hub_download | |
# Model registry stored as ``registry.yaml`` two directory levels above this
# file; entries are looked up by model name and expose a "checkpoint" field.
_REGISTRY_FILE = Path(__file__).parents[1] / "registry.yaml"
with _REGISTRY_FILE.open(encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)
class eqV2(OCPCalculator):
    """OCPCalculator backed by an eqV2 checkpoint from ``fairchem/OMAT24``.

    The checkpoint file is fetched with ``hf_hub_download`` and its local
    path is handed to ``OCPCalculator`` as ``checkpoint_path``.
    """

    def __init__(
        self,
        checkpoint=REGISTRY["eqV2(OMat)"]["checkpoint"],
        cache_dir=None,
        cpu=False,  # TODO: cannot assign device
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an eqV2 calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["eqV2(OMat)"]["checkpoint"]
            File name of the eqV2 checkpoint inside the ``fairchem/OMAT24``
            Hugging Face repository (e.g. ``"eqV2_86M_omat_mp_salex.pt"``).
        cache_dir : str or None, default=None
            Forwarded to ``hf_hub_download`` as ``cache_dir``; it selects
            where the downloaded checkpoint is stored. ``None`` is passed
            through unchanged.
        cpu : bool, default=False
            Whether to run the model on CPU or GPU.
        seed : int, default=0
            The random seed for the model.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        # https://huggingface.co/fairchem/OMAT24/resolve/main/eqV2_86M_omat_mp_salex.pt
        # The download is pinned to one specific repository revision.
        checkpoint_path = hf_hub_download(
            "fairchem/OMAT24",
            filename=checkpoint,
            revision="bf92f9671cb9d5b5c77ecb4aa8b317ff10b882ce",
            cache_dir=cache_dir,
        )

        super().__init__(
            checkpoint_path=checkpoint_path,
            cpu=cpu,
            seed=seed,
            **kwargs,
        )
class EquiformerV2(OCPCalculator):
    """OCPCalculator preset using the registry's ``EquiformerV2(OC22)`` checkpoint."""

    def __init__(
        self,
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an EquiformerV2 calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["EquiformerV2(OC22)"]["checkpoint"]
            Model identifier, handed to the superclass as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Handed to the superclass as ``local_cache``.
        cpu : bool, default=False
            Handed to the superclass as ``cpu``.
        seed : int, default=0
            Handed to the superclass as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        super().__init__(
            model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the superclass calculation, then add a ``force`` entry to the results.

        After the parent ``calculate`` finishes, ``self.results["force"]`` is
        set to whatever ``atoms.get_forces()`` returns.
        """
        super().calculate(atoms, properties, system_changes)
        self.results.update(force=atoms.get_forces())
class EquiformerV2OC20(OCPCalculator):
    """OCPCalculator preset whose default checkpoint comes from the registry."""

    def __init__(
        self,
        # NOTE(review): the class is named OC20 but the default is read from
        # the "EquiformerV2(OC22)" registry entry — confirm this is intended.
        checkpoint=REGISTRY["EquiformerV2(OC22)"]["checkpoint"],
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an EquiformerV2OC20 calculator.

        Parameters
        ----------
        checkpoint : str, default=REGISTRY["EquiformerV2(OC22)"]["checkpoint"]
            Model identifier, handed to the superclass as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Handed to the superclass as ``local_cache``.
        cpu : bool, default=False
            Handed to the superclass as ``cpu``.
        seed : int, default=0
            Handed to the superclass as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        super().__init__(
            model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs
        )
class eSCN(OCPCalculator):
    """OCPCalculator preset defaulting to the ``eSCN-L6-M3-Lay20-S2EF-OC20-All+MD`` model."""

    def __init__(
        self,
        checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """
        Initialize an eSCN calculator.

        Parameters
        ----------
        checkpoint : str, default="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD"
            Model identifier, handed to the superclass as ``model_name``.
        local_cache : str, default="/tmp/ocp/"
            Handed to the superclass as ``local_cache``.
        cpu : bool, default=False
            Handed to the superclass as ``cpu``.
        seed : int, default=0
            Handed to the superclass as ``seed``.

        Other Parameters
        ----------------
        **kwargs
            Any additional keyword arguments are passed to the superclass.
        """
        super().__init__(
            model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs
        )

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the superclass calculation, then add a ``force`` entry to the results.

        After the parent ``calculate`` finishes, ``self.results["force"]`` is
        set to whatever ``atoms.get_forces()`` returns.
        """
        super().calculate(atoms, properties, system_changes)
        self.results.update(force=atoms.get_forces())