import os import urllib from typing import Literal import matgl import requests import torch from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff, default_path from ase import Atoms from chgnet.model.dynamics import CHGNetCalculator from chgnet.model.model import CHGNet from fairchem.core import OCPCalculator from mace.calculators import MACECalculator from matgl.ext.ase import PESCalculator from sevenn.sevennet_calculator import SevenNetCalculator # Avoid circular import def get_freer_device() -> torch.device: """Get the GPU with the most free memory, or use MPS if available. s Returns: torch.device: The selected GPU device or MPS. Raises: ValueError: If no GPU or MPS is available. """ device_count = torch.cuda.device_count() if device_count > 0: # If CUDA GPUs are available, select the one with the most free memory mem_free = [ torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) for i in range(device_count) ] free_gpu_index = mem_free.index(max(mem_free)) device = torch.device(f"cuda:{free_gpu_index}") print( f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs" ) elif torch.backends.mps.is_available(): # If no CUDA GPUs are available but MPS is, use MPS print("No GPU available. Using MPS.") device = torch.device("mps") else: # Fallback to CPU if neither CUDA GPUs nor MPS are available print("No GPU or MPS available. Using CPU.") device = torch.device("cpu") return device class MACE_MP_Medium(MACECalculator): def __init__(self, device=None, default_dtype="float32", **kwargs): checkpoint_url = "http://tinyurl.com/5yyxdm76" cache_dir = os.path.expanduser("~/.cache/mace") checkpoint_url_name = "".join( c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" ) cached_model_path = f"{cache_dir}/{checkpoint_url_name}" if not os.path.isfile(cached_model_path): os.makedirs(cache_dir, exist_ok=True) # download and save to disk print(f"Downloading MACE model from {checkpoint_url!r}") _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) if "Content-Type: text/html" in http_msg: raise RuntimeError( f"Model download failed, please check the URL {checkpoint_url}" ) print(f"Cached MACE model to {cached_model_path}") model = cached_model_path msg = f"Using Materials Project MACE for MACECalculator with {model}" print(msg) device = device or str(get_freer_device()) super().__init__( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) class MACE_OFF_Medium(MACECalculator): def __init__(self, device=None, default_dtype="float32", **kwargs): checkpoint_url = "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true" cache_dir = os.path.expanduser("~/.cache/mace") checkpoint_url_name = "".join( c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" ) cached_model_path = f"{cache_dir}/{checkpoint_url_name}" if not os.path.isfile(cached_model_path): os.makedirs(cache_dir, exist_ok=True) # download and save to disk print(f"Downloading MACE model from {checkpoint_url!r}") _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) if "Content-Type: text/html" in http_msg: raise RuntimeError( f"Model download failed, please check the URL {checkpoint_url}" ) print(f"Cached MACE model to {cached_model_path}") model = cached_model_path msg = f"Using Materials Project MACE for MACECalculator with {model}" print(msg) device = device or str(get_freer_device()) super().__init__( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) class CHGNet(CHGNetCalculator): def __init__( self, model: CHGNet | None = None, use_device: str | None = None, stress_weight: float | None = 1 / 160.21766208, on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", **kwargs, ) -> None: use_device = use_device or str(get_freer_device()) super().__init__( model=model, use_device=use_device, stress_weight=stress_weight, on_isolated_atoms=on_isolated_atoms, **kwargs, ) def calculate( self, atoms: Atoms | None = None, properties: list | None = None, system_changes: list | None = None, ) -> None: super().calculate(atoms, properties, system_changes) # for ase.io.write compatibility self.results.pop("crystal_fea", None) class M3GNet(PESCalculator): def __init__( self, state_attr: torch.Tensor | None = None, stress_weight: float = 1.0, **kwargs, ) -> None: potential = matgl.load_model("M3GNet-MP-2021.2.8-PES") super().__init__(potential, state_attr, stress_weight, **kwargs) class EquiformerV2(OCPCalculator): def __init__( self, model_name="EquiformerV2-lE4-lF100-S2EFS-OC22", local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=model_name, local_cache=local_cache, cpu=cpu, seed=0, **kwargs, ) def calculate(self, atoms: Atoms, properties, system_changes) -> None: super().calculate(atoms, properties, system_changes) self.results.update( force=atoms.get_forces(), ) class EquiformerV2OC20(OCPCalculator): def __init__( self, model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=model_name, local_cache=local_cache, cpu=cpu, seed=0, **kwargs, ) class eSCN(OCPCalculator): def __init__( self, model_name="eSCN-L6-M3-Lay20-S2EF-OC20-All+MD", local_cache="/tmp/ocp/", cpu=False, seed=0, **kwargs, ) -> None: super().__init__( model_name=model_name, local_cache=local_cache, cpu=cpu, seed=0, **kwargs, ) def calculate(self, atoms: Atoms, properties, system_changes) -> None: super().calculate(atoms, properties, system_changes) self.results.update( force=atoms.get_forces(), ) class ALIGNN(AlignnAtomwiseCalculator): def __init__(self, dir_path: str = "/tmp/alignn/", device=None, **kwargs) -> None: model_path = get_figshare_model_ff(dir_path=dir_path) device = device or get_freer_device() super().__init__(path=dir_path, device=device, **kwargs) def calculate(self, atoms, properties=None, system_changes=None): super().calculate(atoms, properties, system_changes) class SevenNet(SevenNetCalculator): def __init__(self, device=None, **kwargs): # url = ( # "https://github.com/MDIL-SNU/SevenNet/raw/main/pretrained_potentials" # "/SevenNet_0__11July2024/checkpoint_sevennet_0.pth" # ) # ckpt_cache = "/tmp/sevennet_checkpoint.pth.tar" # response = requests.get(url) # with open(ckpt_cache, mode="wb") as file: # file.write(response.content) device = device or get_freer_device() super().__init__("7net-0", device=device, **kwargs)