Spaces:
Running
Running
"""Utility functions for MLIP models.""" | |
import importlib | |
from enum import Enum | |
from typing import Any | |
import torch | |
from ase.calculators.calculator import Calculator | |
from mlip_arena.models import REGISTRY | |
# Map every registered model name to the attribute of the same name found in
# the sibling module listed under that registry entry's "module" key. The
# modules are imported eagerly, when this module itself is imported.
MLIPMap = {
    model_name: getattr(
        importlib.import_module(f"{__package__}.{entry['module']}"),
        model_name,
    )
    for model_name, entry in REGISTRY.items()
}
class EXTMLIPEnum(Enum):
    """Names of the external (third-party) MLIP models known to this module.

    Every member's value is the model name as a string, identical to the
    member's own name.

    Attributes:
        M3GNet (str): M3GNet model.
        CHGNet (str): CHGNet model.
        MACE (str): MACE model.
    """

    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    MACE = "MACE"
def get_freer_device() -> torch.device:
    """Pick the CUDA device with the most unallocated memory, else the CPU.

    For each visible CUDA device the free-memory estimate is the device's
    ``total_memory`` minus ``torch.cuda.memory_allocated`` for that device.
    The device with the largest estimate is selected; ties go to the lowest
    device index. The outcome is reported on stdout via ``print``.

    Returns:
        torch.device: ``cuda:<index>`` for the selected GPU, or ``cpu`` when
            no CUDA device is available. No exception is raised in that case.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        print("No GPU available. Using CPU.")
        return torch.device("cpu")

    # NOTE(review): per the PyTorch docs, memory_allocated() reports memory
    # occupied by tensors of the current process only, so memory used by other
    # processes on the same GPU is not reflected in this estimate.
    mem_free = [
        torch.cuda.get_device_properties(index).total_memory
        - torch.cuda.memory_allocated(index)
        for index in range(device_count)
    ]

    # max() returns the first maximal item, so ties go to the lowest index.
    free_gpu_index = max(range(device_count), key=mem_free.__getitem__)

    print(
        f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
    )
    return torch.device(f"cuda:{free_gpu_index}")
def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
    """Construct an ASE calculator from an external third-party MLIP package.

    The backing package is imported inside the matching branch, so only the
    package for the requested model has to be importable.

    Args:
        name: Member of ``EXTMLIPEnum`` selecting which model to build.
        **kwargs: Forwarded unchanged to the underlying constructor
            (``mace_mp``, ``CHGNetCalculator`` or ``PESCalculator``).

    Returns:
        Calculator: The constructed calculator, with its ``name`` attribute
            set to ``name.value``.

    Raises:
        ValueError: If ``name`` is not one of the supported ``EXTMLIPEnum``
            members.
    """
    device = get_freer_device()

    if name == EXTMLIPEnum.MACE:
        from mace.calculators import mace_mp

        calculator = mace_mp(device=str(device), **kwargs)
    elif name == EXTMLIPEnum.CHGNet:
        from chgnet.model.dynamics import CHGNetCalculator

        calculator = CHGNetCalculator(use_device=str(device), **kwargs)
    elif name == EXTMLIPEnum.M3GNet:
        import matgl
        from matgl.ext.ase import PESCalculator

        # The selected device is not passed to matgl in this branch.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        calculator = PESCalculator(potential, **kwargs)
    else:
        # Fail with a clear message instead of an AttributeError on None below.
        supported = [member.name for member in EXTMLIPEnum]
        raise ValueError(
            f"Unsupported external MLIP {name!r}; expected one of {supported}",
        )

    calculator.name = name.value
    return calculator