Spaces:
Running
Running
File size: 2,231 Bytes
ee784cf 49d0cfc ee784cf 49d0cfc ee784cf 0b5acc7 ee784cf 0b5acc7 ee784cf 0b5acc7 ee784cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
"""Utility functions for MLIP models."""
import importlib
from enum import Enum
from typing import Any
import torch
from ase.calculators.calculator import Calculator
from mlip_arena.models import REGISTRY
def _resolve_model_class(model_name: str, module_name: str) -> Any:
    """Import ``<this package>.<module_name>`` and return its ``model_name`` attribute."""
    submodule = importlib.import_module(f"{__package__}.{module_name}")
    return getattr(submodule, model_name)


# Registered model name -> the attribute of the same name exported by the
# submodule listed under that entry's "module" key in REGISTRY
# (presumably the model's calculator class).
MLIPMap = {
    model_name: _resolve_model_class(model_name, entry["module"])
    for model_name, entry in REGISTRY.items()
}
class EXTMLIPEnum(Enum):
    """Enumeration of external (third-party) MLIP models.

    ``external_ase_calculator`` compares its ``name`` argument against these
    members with ``==`` to choose which calculator to build, then copies the
    member's ``value`` onto the calculator's ``name`` attribute.

    Attributes:
        M3GNet (str): M3GNet model, built via the ``matgl`` package.
        CHGNet (str): CHGNet model, built via the ``chgnet`` package.
        MACE (str): MACE model, built via the ``mace`` package.
    """

    M3GNet = "M3GNet"
    CHGNet = "CHGNet"
    MACE = "MACE"
def get_freer_device() -> torch.device:
    """Pick a torch device, preferring the CUDA GPU with the most free memory.

    For each visible CUDA device, "free" memory is computed as
    ``total_memory - memory_allocated``. If no CUDA device is visible, the
    function falls back to the CPU instead of raising.

    Returns:
        torch.device: ``cuda:<index>`` of the GPU with the largest computed
        free memory (the lowest index wins ties), or ``cpu`` when
        ``torch.cuda.device_count()`` is 0.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        print("No GPU available. Using CPU.")
        return torch.device("cpu")

    # NOTE: torch.cuda.memory_allocated only counts tensors allocated by *this*
    # process through PyTorch's allocator, so memory used by other processes on
    # the same GPU is not subtracted. torch.cuda.mem_get_info would report
    # device-wide free memory, but it likely initializes a CUDA context on every
    # GPU it queries, so the cheaper per-process estimate is kept here.
    mem_free = [
        torch.cuda.get_device_properties(i).total_memory
        - torch.cuda.memory_allocated(i)
        for i in range(device_count)
    ]
    # max() returns the first maximal index, matching list.index(max(...)).
    free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
    print(
        f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
    )
    return torch.device(f"cuda:{free_gpu_index}")
def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
    """Construct an ASE calculator from an external third-party MLIP package.

    Args:
        name: Which external model to build. Either an ``EXTMLIPEnum`` member
            or its string value (e.g. ``"MACE"``).
        **kwargs: Forwarded unchanged to the underlying calculator factory
            (``mace_mp``, ``CHGNetCalculator`` or ``PESCalculator``).

    Returns:
        Calculator: The constructed calculator, with its ``name`` attribute set
        to the enum member's value.

    Raises:
        ValueError: If ``name`` is not an ``EXTMLIPEnum`` member or value, or
            is a member with no construction branch below.
    """
    # EXTMLIPEnum(member) returns the member unchanged; EXTMLIPEnum("MACE")
    # looks the member up by value. An unknown value raises ValueError here
    # instead of falling through to an opaque AttributeError on None later.
    name = EXTMLIPEnum(name)

    device = get_freer_device()

    if name == EXTMLIPEnum.MACE:
        from mace.calculators import mace_mp

        calculator = mace_mp(device=str(device), **kwargs)
    elif name == EXTMLIPEnum.CHGNet:
        from chgnet.model.dynamics import CHGNetCalculator

        calculator = CHGNetCalculator(use_device=str(device), **kwargs)
    elif name == EXTMLIPEnum.M3GNet:
        import matgl
        from matgl.ext.ase import PESCalculator

        # NOTE(review): the selected `device` is not passed to matgl here;
        # confirm which device PESCalculator ends up running on.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        calculator = PESCalculator(potential, **kwargs)
    else:
        # Reached only if EXTMLIPEnum gains a member without a branch above.
        raise ValueError(f"Unsupported external MLIP: {name.value}")

    calculator.name = name.value
    return calculator
|