cyrusyc's picture
update before hf
0b5acc7
raw
history blame
2.23 kB
"""Utility functions for MLIP models."""
import importlib
from enum import Enum
from typing import Any
import torch
from ase.calculators.calculator import Calculator
from mlip_arena.models import REGISTRY
def _resolve_model_classes() -> dict:
    """Import every registered MLIP class and map model name -> class object.

    For each ``(model_name, metadata)`` pair in ``REGISTRY``, the submodule
    ``{__package__}.{metadata['module']}`` is imported and the attribute
    called ``model_name`` is fetched from it. Import or lookup errors are not
    caught, so a broken registry entry fails when this module is imported.
    """
    resolved = {}
    for model_name, model_meta in REGISTRY.items():
        module_path = f"{__package__}.{model_meta['module']}"
        model_module = importlib.import_module(module_path)
        resolved[model_name] = getattr(model_module, model_name)
    return resolved


MLIPMap = _resolve_model_classes()
class EXTMLIPEnum(Enum):
    """Identifiers of the external (third-party) MLIP packages.

    Every member's value is the same string as its name, so a member can be
    obtained either by attribute (``EXTMLIPEnum.MACE``) or by value lookup
    (``EXTMLIPEnum("MACE")``).

    Attributes:
        M3GNet (str): M3GNet model.
        CHGNet (str): CHGNet model.
        MACE (str): MACE model.
    """

    # Declaration order defines iteration order over the enum; keep it stable.
    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    MACE = "MACE"
def get_freer_device() -> torch.device:
    """Pick the CUDA device with the most free memory, falling back to CPU.

    Free memory is read from the CUDA driver via ``torch.cuda.mem_get_info``,
    so it reflects memory held by all processes on each GPU, not just
    tensors allocated by this process's PyTorch caching allocator.

    Returns:
        torch.device: ``cuda:<index>`` for the GPU with the most free memory
        (the lowest index wins on ties), or ``cpu`` when no CUDA device is
        visible.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        print("No GPU available. Using CPU.")
        return torch.device("cpu")

    # NOTE(review): querying cudaMemGetInfo presumably creates a CUDA context
    # on each visible GPU in this process — confirm this overhead is acceptable.
    mem_free = [torch.cuda.mem_get_info(i)[0] for i in range(device_count)]
    # max() returns the first maximal index, matching list.index(max(...)).
    free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
    print(
        f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
    )
    return torch.device(f"cuda:{free_gpu_index}")
def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
    """Construct an ASE calculator from an external third-party MLIP package.

    Args:
        name: Which external model to build: an ``EXTMLIPEnum`` member or its
            string value (e.g. ``"MACE"``).
        **kwargs: Forwarded to the package's calculator constructor
            (``mace_mp``, ``CHGNetCalculator`` or matgl's ``PESCalculator``).
            For MACE and CHGNet, an explicit ``device`` / ``use_device`` given
            here takes precedence over the automatically selected device.

    Returns:
        Calculator: The constructed calculator, with its ``name`` attribute set
        to the enum value of the chosen model.

    Raises:
        ValueError: If ``name`` does not correspond to a supported model.
    """
    # Normalize strings like "MACE" to the enum member. Unknown names raise a
    # clear ValueError here instead of failing later on ``None``.
    model = EXTMLIPEnum(name)
    device = get_freer_device()

    if model == EXTMLIPEnum.MACE:
        from mace.calculators import mace_mp

        kwargs.setdefault("device", str(device))
        calculator = mace_mp(**kwargs)
    elif model == EXTMLIPEnum.CHGNet:
        from chgnet.model.dynamics import CHGNetCalculator

        kwargs.setdefault("use_device", str(device))
        calculator = CHGNetCalculator(**kwargs)
    elif model == EXTMLIPEnum.M3GNet:
        import matgl
        from matgl.ext.ase import PESCalculator

        # NOTE(review): the selected device is not passed to PESCalculator, so
        # it runs on matgl's own default device — confirm this is intended.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        calculator = PESCalculator(potential, **kwargs)
    else:
        # Guards against enum members added without a matching branch.
        raise ValueError(f"Unsupported external MLIP: {model}")

    calculator.name = model.value
    return calculator