File size: 2,231 Bytes
ee784cf
 
49d0cfc
 
ee784cf
 
 
 
49d0cfc
 
 
ee784cf
 
0b5acc7
ee784cf
 
 
 
 
 
0b5acc7
ee784cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b5acc7
ee784cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Utility functions for MLIP models."""

import importlib
from enum import Enum
from typing import Any

import torch
from ase.calculators.calculator import Calculator

from mlip_arena.models import REGISTRY

# Map every registered model name to the attribute of the same name found in
# its implementation module. Each REGISTRY entry names that module (relative
# to this package) under its "module" key; all modules are imported eagerly
# when this file is imported.
MLIPMap = {
    model_name: getattr(
        importlib.import_module(f"{__package__}.{entry['module']}"),
        model_name,
    )
    for model_name, entry in REGISTRY.items()
}


class EXTMLIPEnum(Enum):
    """Identifiers of MLIP models that come from external third-party packages.

    Every member's value is the plain model-name string, identical to the
    member name. Members are declared in the order M3GNet, CHGNet, MACE.
    """

    # Built through the ``matgl`` package.
    M3GNet = "M3GNet"
    # Built through the ``chgnet`` package.
    CHGNet = "CHGNet"
    # Built through the ``mace`` package (``mace_mp``).
    MACE = "MACE"


def get_freer_device() -> torch.device:
    """Get the GPU with the most free memory, falling back to the CPU.

    Free memory per device is taken from ``torch.cuda.mem_get_info``, which
    reports the driver's view of free memory. Unlike
    ``total_memory - memory_allocated``, this also accounts for memory held
    by other processes and by PyTorch's caching allocator. When several GPUs
    tie, the lowest index wins.

    Note: querying the driver may initialise a CUDA context on every visible
    device.

    Returns:
        torch.device: ``cuda:<index>`` for the GPU with the most free memory,
        or ``cpu`` when no CUDA device is visible.
    """
    device_count = torch.cuda.device_count()
    if device_count == 0:
        print("No GPU available. Using CPU.")
        return torch.device("cpu")

    # mem_get_info(i) returns (free_bytes, total_bytes) for device i.
    mem_free = [torch.cuda.mem_get_info(i)[0] for i in range(device_count)]

    # max() returns the first maximal index, matching list.index(max(...)).
    free_gpu_index = max(range(device_count), key=mem_free.__getitem__)

    print(
        f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
    )

    return torch.device(f"cuda:{free_gpu_index}")


def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
    """Construct an ASE calculator from an external third-party MLIP package.

    Args:
        name: The external model to build, either an ``EXTMLIPEnum`` member
            or its string value (e.g. ``"MACE"``).
        **kwargs: Forwarded unchanged to the underlying constructor
            (``mace_mp``, ``CHGNetCalculator`` or ``PESCalculator``).

    Returns:
        Calculator: The constructed calculator, with its ``name`` attribute
        set to the model's enum value.

    Raises:
        ValueError: If ``name`` is not a supported external MLIP.
    """
    # Normalise up front: this accepts plain strings and makes unknown names
    # fail with a clear ValueError, instead of an AttributeError on ``None``
    # after a device has already been chosen.
    name = EXTMLIPEnum(name)
    device = get_freer_device()

    if name == EXTMLIPEnum.MACE:
        from mace.calculators import mace_mp

        calculator = mace_mp(device=str(device), **kwargs)

    elif name == EXTMLIPEnum.CHGNet:
        from chgnet.model.dynamics import CHGNetCalculator

        calculator = CHGNetCalculator(use_device=str(device), **kwargs)

    elif name == EXTMLIPEnum.M3GNet:
        import matgl
        from matgl.ext.ase import PESCalculator

        # NOTE(review): the selected ``device`` is not passed to matgl here;
        # confirm which device PESCalculator ends up running on.
        potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
        calculator = PESCalculator(potential, **kwargs)

    else:
        # Only reachable if EXTMLIPEnum gains a member without a branch here.
        raise ValueError(f"Unsupported external MLIP: {name.value}")

    calculator.name = name.value

    return calculator