from __future__ import annotations import os from pathlib import Path from typing import Literal import matgl import requests import torch from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff, default_path from ase import Atoms from chgnet.model.dynamics import CHGNetCalculator from chgnet.model.model import CHGNet as CHGNetModel from fairchem.core import OCPCalculator from mace.calculators import MACECalculator from matgl.ext.ase import PESCalculator from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator from sevenn.sevennet_calculator import SevenNetCalculator # Avoid circular import def get_freer_device() -> torch.device: """Get the GPU with the most free memory, or use MPS if available. s Returns: torch.device: The selected GPU device or MPS. Raises: ValueError: If no GPU or MPS is available. """ device_count = torch.cuda.device_count() if device_count > 0: # If CUDA GPUs are available, select the one with the most free memory mem_free = [ torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) for i in range(device_count) ] free_gpu_index = mem_free.index(max(mem_free)) device = torch.device(f"cuda:{free_gpu_index}") print( f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs" ) elif torch.backends.mps.is_available(): # If no CUDA GPUs are available but MPS is, use MPS print("No GPU available. Using MPS.") device = torch.device("mps") else: # Fallback to CPU if neither CUDA GPUs nor MPS are available print("No GPU or MPS available. Using CPU.") device = torch.device("cpu") return device class MACE_MP_Medium(MACECalculator): def __init__( self, checkpoint="http://tinyurl.com/5yyxdm76", device: str | None = None, default_dtype="float32", **kwargs, ): cache_dir = Path.home() / ".cache" / "mace" checkpoint_url_name = "".join( c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_" ) cached_model_path = f"{cache_dir}/{checkpoint_url_name}" if not os.path.isfile(cached_model_path): import urllib os.makedirs(cache_dir, exist_ok=True) _, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path) if "Content-Type: text/html" in http_msg: raise RuntimeError( f"Model download failed, please check the URL {checkpoint}" ) model = cached_model_path device = device or str(get_freer_device()) super().__init__( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) # TODO: could share the same class with MACE_MP_Medium class MACE_OFF_Medium(MACECalculator): def __init__( self, checkpoint="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", device: str | None = None, default_dtype="float32", **kwargs, ): cache_dir = Path.home() / ".cache" / "mace" checkpoint_url_name = "".join( c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_" ) cached_model_path = f"{cache_dir}/{checkpoint_url_name}" if not os.path.isfile(cached_model_path): import urllib os.makedirs(cache_dir, exist_ok=True) _, http_msg = urllib.request.urlretrieve(checkpoint, cached_model_path) if "Content-Type: text/html" in http_msg: raise RuntimeError( f"Model download failed, please check the URL {checkpoint}" ) model = cached_model_path device = device or str(get_freer_device()) super().__init__( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) class CHGNet(CHGNetCalculator): def __init__( self, checkpoint: CHGNetModel | None = None, # TODO: specifiy version device: str | None = None, stress_weight: float | None = 1 / 160.21766208, on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", **kwargs, ) -> None: use_device = device or str(get_freer_device()) super().__init__( model=checkpoint, use_device=use_device, stress_weight=stress_weight, on_isolated_atoms=on_isolated_atoms, **kwargs, ) def calculate( self, atoms: Atoms | None = None, properties: list | None = None, system_changes: list | None = None, ) -> None: super().calculate(atoms, properties, system_changes) # for ase.io.write compatibility self.results.pop("crystal_fea", None) class M3GNet(PESCalculator): def __init__( self, checkpoint="M3GNet-MP-2021.2.8-PES", # TODO: cannot assign device state_attr: torch.Tensor | None = None, stress_weight: float = 1.0, **kwargs, ) -> None: potential = matgl.load_model(checkpoint) super().__init__(potential, state_attr, stress_weight, **kwargs) class EquiformerV2(OCPCalculator): def __init__( self, checkpoint="EquiformerV2-lE4-lF100-S2EFS-OC22", # TODO: import from registry # TODO: cannot assign device local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs, ) def calculate(self, atoms: Atoms, properties, system_changes) -> None: super().calculate(atoms, properties, system_changes) self.results.update( force=atoms.get_forces(), ) class EquiformerV2OC20(OCPCalculator): def __init__( self, checkpoint="EquiformerV2-31M-S2EF-OC20-All+MD", # TODO: import from registry # TODO: cannot assign device local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs, ) class eSCN(OCPCalculator): def __init__( self, checkpoint="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD", # TODO: import from registry # TODO: cannot assign device local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=checkpoint, local_cache=local_cache, cpu=cpu, seed=seed, **kwargs, ) def calculate(self, atoms: Atoms, properties, system_changes) -> None: super().calculate(atoms, properties, system_changes) self.results.update( force=atoms.get_forces(), ) class ALIGNN(AlignnAtomwiseCalculator): def __init__(self, device=None, **kwargs) -> None: # TODO: cannot control version # _ = get_figshare_model_ff(dir_path=dir_path) model_path = default_path() device = device or get_freer_device() super().__init__(path=model_path, device=device, **kwargs) class SevenNet(SevenNetCalculator): def __init__( self, checkpoint="7net-0", # TODO: import from registry device=None, **kwargs, ): device = device or get_freer_device() super().__init__(checkpoint, device=device, **kwargs) class ORB(ORBCalculator): def __init__( self, checkpoint="orbff-v1-20240827.ckpt", device=None, **kwargs, ): device = device or get_freer_device() cache_dir = Path.home() / ".cache" / "orb" cache_dir.mkdir(parents=True, exist_ok=True) ckpt_path = cache_dir / "orbff-v1-20240827.ckpt" url = f"https://storage.googleapis.com/orbitalmaterials-public-models/forcefields/{checkpoint}" if not ckpt_path.exists(): print(f"Downloading ORB model from {url} to {ckpt_path}...") try: response = requests.get(url, stream=True, timeout=120) response.raise_for_status() with open(ckpt_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) print("Download completed.") except requests.exceptions.RequestException as e: raise RuntimeError("Failed to download ORB model.") from e orbff = pretrained.orb_v1(weights_path=ckpt_path, device=device) super().__init__(orbff, device=device, **kwargs)