# Scraped page header (not part of the module):
# last commit 1c7cd6c (unverified) "Add calculator test (#6)" by Yuan (Cyrus) Chiang;
# file size 9.08 kB.
from __future__ import annotations
import os
from pathlib import Path
from typing import Literal
import matgl
import requests
import torch
from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff, default_path
from ase import Atoms
from chgnet.model.dynamics import CHGNetCalculator
from chgnet.model.model import CHGNet as CHGNetModel
from fairchem.core import OCPCalculator
from mace.calculators import MACECalculator
from matgl.ext.ase import PESCalculator
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
from sevenn.sevennet_calculator import SevenNetCalculator
# Avoid circular import
def get_freer_device() -> torch.device:
    """Pick the torch device to run on, preferring the least-loaded CUDA GPU.

    Selection order:
        1. CUDA: the GPU with the largest value of
           ``total_memory - torch.cuda.memory_allocated(i)``.
        2. MPS, when no CUDA device is visible but the MPS backend is available.
        3. CPU otherwise.

    The selected device is reported on stdout.

    Returns:
        torch.device: The selected CUDA, MPS, or CPU device. No exception is
        raised when no accelerator is found; the function falls back to CPU.
    """
    device_count = torch.cuda.device_count()

    if device_count > 0:
        # NOTE(review): ``memory_allocated`` presumably only reflects tensors
        # allocated by this process, so memory held by other processes on the
        # same GPU is not accounted for -- confirm before relying on this in
        # shared multi-process setups.
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        # First index wins on ties, matching ``list.index(max(...))``.
        free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
        device = torch.device(f"cuda:{free_gpu_index}")
        print(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
        return device

    if torch.backends.mps.is_available():
        # No CUDA GPU is visible, but the MPS backend can be used.
        print("No GPU available. Using MPS.")
        return torch.device("mps")

    # Neither CUDA nor MPS is usable: fall back to CPU.
    print("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")
class MACE_MP_Medium(MACECalculator):
    """MACE calculator that downloads its checkpoint once and caches it locally.

    The checkpoint is stored under ``~/.cache/mace`` using a file name derived
    from the last path component of the URL (only alphanumerics and ``_`` are
    kept).
    """

    def __init__(
        self,
        checkpoint="http://tinyurl.com/5yyxdm76",
        device: str | None = None,
        default_dtype="float32",
        **kwargs,
    ):
        """Fetch (if needed) the checkpoint and initialise the MACE calculator.

        Args:
            checkpoint: URL of the model file to download.
            device: Device string forwarded to ``MACECalculator``. When falsy,
                the device chosen by ``get_freer_device()`` is used.
            default_dtype: Forwarded to ``MACECalculator``.
            **kwargs: Extra keyword arguments forwarded to ``MACECalculator``.

        Raises:
            RuntimeError: If the server answers with an HTML page instead of a
                model file.
        """
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly.
        import urllib.request

        cache_dir = Path.home() / ".cache" / "mace"
        checkpoint_url_name = "".join(
            c for c in os.path.basename(checkpoint) if c.isalnum() or c == "_"
        )
        cached_model_path = cache_dir / checkpoint_url_name

        if not cached_model_path.is_file():
            cache_dir.mkdir(parents=True, exist_ok=True)
            _, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path)
            # ``in`` on the returned message object tests header *names*, so the
            # content type has to be inspected explicitly to detect an HTML
            # (error/landing) page.
            if http_msg.get_content_type() == "text/html":
                # Do not leave the HTML page behind as a "cached model".
                cached_model_path.unlink(missing_ok=True)
                raise RuntimeError(
                    f"Model download failed, please check the URL {checkpoint}"
                )

        model = str(cached_model_path)
        device = device or str(get_freer_device())
        super().__init__(
            model_paths=model, device=device, default_dtype=default_dtype, **kwargs
        )
# TODO: could share the same class with MACE_MP_Medium
class MACE_OFF_Medium(MACECalculator):
    """MACE-OFF23 (medium) calculator with a locally cached checkpoint.

    The checkpoint is stored under ``~/.cache/mace`` using a file name derived
    from the last path component of the URL (only alphanumerics and ``_`` are
    kept).
    """

    def __init__(
        self,
        checkpoint="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true",
        device: str | None = None,
        default_dtype="float32",
        **kwargs,
    ):
        """Fetch (if needed) the checkpoint and initialise the MACE calculator.

        Args:
            checkpoint: URL of the model file to download.
            device: Device string forwarded to ``MACECalculator``. When falsy,
                the device chosen by ``get_freer_device()`` is used.
            default_dtype: Forwarded to ``MACECalculator``.
            **kwargs: Extra keyword arguments forwarded to ``MACECalculator``.

        Raises:
            RuntimeError: If the server answers with an HTML page instead of a
                model file.
        """
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly.
        import urllib.request

        cache_dir = Path.home() / ".cache" / "mace"
        checkpoint_url_name = "".join(
            c for c in os.path.basename(checkpoint) if c.isalnum() or c == "_"
        )
        cached_model_path = cache_dir / checkpoint_url_name

        if not cached_model_path.is_file():
            cache_dir.mkdir(parents=True, exist_ok=True)
            _, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path)
            # ``in`` on the returned message object tests header *names*, so the
            # content type has to be inspected explicitly to detect an HTML
            # (error/landing) page.
            if http_msg.get_content_type() == "text/html":
                # Do not leave the HTML page behind as a "cached model".
                cached_model_path.unlink(missing_ok=True)
                raise RuntimeError(
                    f"Model download failed, please check the URL {checkpoint}"
                )

        model = str(cached_model_path)
        device = device or str(get_freer_device())
        super().__init__(
            model_paths=model, device=device, default_dtype=default_dtype, **kwargs
        )
class CHGNet(CHGNetCalculator):
    """CHGNet calculator with automatic device selection.

    After each calculation the ``"crystal_fea"`` entry is removed from
    ``results`` so the results stay compatible with ``ase.io.write``.
    """

    def __init__(
        self,
        checkpoint: CHGNetModel | None = None,  # TODO: specify version
        device: str | None = None,
        stress_weight: float | None = 1 / 160.21766208,
        on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
        **kwargs,
    ) -> None:
        """Initialise the underlying ``CHGNetCalculator``.

        Args:
            checkpoint: Model instance passed to the parent as ``model``.
            device: Device string; when falsy, ``get_freer_device()`` decides.
            stress_weight: Forwarded to the parent unchanged.
                NOTE(review): the default looks like a GPa <-> eV/A^3
                conversion factor -- confirm against the CHGNet docs.
            on_isolated_atoms: Forwarded to the parent unchanged.
            **kwargs: Extra keyword arguments for ``CHGNetCalculator``.
        """
        if not device:
            device = str(get_freer_device())
        parent_options = dict(
            model=checkpoint,
            use_device=device,
            stress_weight=stress_weight,
            on_isolated_atoms=on_isolated_atoms,
        )
        super().__init__(**parent_options, **kwargs)

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list | None = None,
        system_changes: list | None = None,
    ) -> None:
        """Run the parent calculation, then discard the ``"crystal_fea"`` result."""
        super().calculate(atoms, properties, system_changes)

        # Dropped for ase.io.write compatibility.
        if "crystal_fea" in self.results:
            del self.results["crystal_fea"]
class M3GNet(PESCalculator):
    """``PESCalculator`` backed by a potential loaded through ``matgl.load_model``."""

    def __init__(
        self,
        checkpoint="M3GNet-MP-2021.2.8-PES",
        # TODO: cannot assign device
        state_attr: torch.Tensor | None = None,
        stress_weight: float = 1.0,
        **kwargs,
    ) -> None:
        """Load the named model and hand it to ``PESCalculator``.

        Args:
            checkpoint: Model identifier given to ``matgl.load_model``.
            state_attr: Forwarded positionally to ``PESCalculator``.
            stress_weight: Forwarded positionally to ``PESCalculator``.
            **kwargs: Extra keyword arguments for ``PESCalculator``.
        """
        super().__init__(
            matgl.load_model(checkpoint),
            state_attr,
            stress_weight,
            **kwargs,
        )
class EquiformerV2(OCPCalculator):
    """``OCPCalculator`` preset that defaults to an EquiformerV2 OC22 checkpoint."""

    def __init__(
        self,
        checkpoint="EquiformerV2-lE4-lF100-S2EFS-OC22",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Initialise ``OCPCalculator`` with ``checkpoint`` as its ``model_name``.

        Args:
            checkpoint: Model name passed to the parent as ``model_name``.
            local_cache: Forwarded to the parent unchanged.
            cpu: Forwarded to the parent unchanged.
            seed: Forwarded to the parent unchanged.
            **kwargs: Extra keyword arguments for ``OCPCalculator``.
        """
        parent_options = {
            "model_name": checkpoint,
            "local_cache": local_cache,
            "cpu": cpu,
            "seed": seed,
        }
        super().__init__(**parent_options, **kwargs)

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the parent calculation and store ``atoms.get_forces()`` as ``"force"``."""
        super().calculate(atoms, properties, system_changes)
        # Adds a "force" entry to the results; other entries are left as-is.
        self.results["force"] = atoms.get_forces()
class EquiformerV2OC20(OCPCalculator):
    """``OCPCalculator`` preset that defaults to an EquiformerV2 OC20 checkpoint."""

    def __init__(
        self,
        checkpoint="EquiformerV2-31M-S2EF-OC20-All+MD",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Initialise ``OCPCalculator`` with ``checkpoint`` as its ``model_name``.

        Args:
            checkpoint: Model name passed to the parent as ``model_name``.
            local_cache: Forwarded to the parent unchanged.
            cpu: Forwarded to the parent unchanged.
            seed: Forwarded to the parent unchanged.
            **kwargs: Extra keyword arguments for ``OCPCalculator``.
        """
        parent_options = {
            "model_name": checkpoint,
            "local_cache": local_cache,
            "cpu": cpu,
            "seed": seed,
        }
        super().__init__(**parent_options, **kwargs)
class eSCN(OCPCalculator):
    """``OCPCalculator`` preset that defaults to an eSCN OC20 checkpoint."""

    def __init__(
        self,
        checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD",  # TODO: import from registry
        # TODO: cannot assign device
        local_cache="/tmp/ocp/",
        cpu=False,
        seed=0,
        **kwargs,
    ) -> None:
        """Initialise ``OCPCalculator`` with ``checkpoint`` as its ``model_name``.

        Args:
            checkpoint: Model name passed to the parent as ``model_name``.
            local_cache: Forwarded to the parent unchanged.
            cpu: Forwarded to the parent unchanged.
            seed: Forwarded to the parent unchanged.
            **kwargs: Extra keyword arguments for ``OCPCalculator``.
        """
        parent_options = {
            "model_name": checkpoint,
            "local_cache": local_cache,
            "cpu": cpu,
            "seed": seed,
        }
        super().__init__(**parent_options, **kwargs)

    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
        """Run the parent calculation and store ``atoms.get_forces()`` as ``"force"``."""
        super().calculate(atoms, properties, system_changes)
        # Adds a "force" entry to the results; other entries are left as-is.
        self.results["force"] = atoms.get_forces()
class ALIGNN(AlignnAtomwiseCalculator):
    """ALIGNN force-field calculator using the model at ``default_path()``."""

    def __init__(self, device=None, **kwargs) -> None:
        """Initialise the calculator on ``device``.

        Args:
            device: Target device; when falsy, ``get_freer_device()`` decides.
            **kwargs: Extra keyword arguments for ``AlignnAtomwiseCalculator``.
        """
        # TODO: cannot control version
        # (an explicit ``get_figshare_model_ff(dir_path=...)`` download is
        # currently not performed here)
        if not device:
            device = get_freer_device()
        super().__init__(path=default_path(), device=device, **kwargs)
class SevenNet(SevenNetCalculator):
    """``SevenNetCalculator`` preset with automatic device selection."""

    def __init__(
        self,
        checkpoint="7net-0",  # TODO: import from registry
        device=None,
        **kwargs,
    ):
        """Initialise the calculator.

        Args:
            checkpoint: Passed as the first positional argument to the parent.
            device: Target device; when falsy, ``get_freer_device()`` decides.
            **kwargs: Extra keyword arguments for ``SevenNetCalculator``.
        """
        if not device:
            device = get_freer_device()
        super().__init__(checkpoint, device=device, **kwargs)
class ORB(ORBCalculator):
    """ORB force-field calculator with a locally cached checkpoint.

    The checkpoint is downloaded on first use into ``~/.cache/orb`` and loaded
    through ``pretrained.orb_v1``.
    """

    def __init__(
        self,
        checkpoint="orbff-v1-20240827.ckpt",
        device=None,
        **kwargs,
    ):
        """Download (if needed) and load the ORB model, then init the calculator.

        Args:
            checkpoint: Checkpoint file name, appended to the public model
                bucket URL; also used as the cache file name.
            device: Target device; when falsy, ``get_freer_device()`` decides.
            **kwargs: Extra keyword arguments for ``ORBCalculator``.

        Raises:
            RuntimeError: If the download fails with a ``requests`` error.
        """
        device = device or get_freer_device()

        cache_dir = Path.home() / ".cache" / "orb"
        cache_dir.mkdir(parents=True, exist_ok=True)
        # Cache under the requested checkpoint's own file name, so that a
        # non-default checkpoint is neither stored as nor replaced by the
        # default one.
        ckpt_path = cache_dir / Path(checkpoint).name
        url = f"https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/{checkpoint}"

        if not ckpt_path.exists():
            print(f"Downloading ORB model from {url} to {ckpt_path}...")
            # Stream into a temporary file and rename it only once complete, so
            # an interrupted download is never mistaken for a valid cache entry.
            tmp_path = ckpt_path.with_name(ckpt_path.name + ".part")
            try:
                with requests.get(url, stream=True, timeout=120) as response:
                    response.raise_for_status()
                    with open(tmp_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                tmp_path.replace(ckpt_path)
            except requests.exceptions.RequestException as e:
                raise RuntimeError("Failed to download ORB model.") from e
            finally:
                # No-op after a successful rename; removes partial data otherwise.
                tmp_path.unlink(missing_ok=True)
            print("Download completed.")

        orbff = pretrained.orb_v1(weights_path=ckpt_path, device=device)
        super().__init__(orbff, device=device, **kwargs)