File size: 3,573 Bytes
ee784cf
 
49d0cfc
 
ee784cf
 
49d0cfc
 
 
ee784cf
 
5b01054
ee784cf
 
 
7cbf186
ee784cf
 
7cbf186
 
 
 
 
ee784cf
7cbf186
 
ee784cf
7cbf186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee784cf
7cbf186
ee784cf
 
7cbf186
 
ee784cf
7cbf186
 
 
 
 
ee784cf
7cbf186
 
 
 
ee784cf
 
7cbf186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee784cf
 
7cbf186
 
 
 
ee784cf
7cbf186
 
ee784cf
7cbf186
ee784cf
7cbf186
 
ee784cf
7cbf186
ee784cf
7cbf186
 
 
ee784cf
7cbf186
 
ee784cf
 
 
7cbf186
ee784cf
7cbf186
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Utility functions for MLIP models."""

import importlib
from enum import Enum

import torch

from mlip_arena.models import REGISTRY

def _resolve_model_class(metadata: dict) -> type:
    """Import the submodule named by ``metadata['module']`` and return its ``metadata['class']`` attribute."""
    module = importlib.import_module(f"{__package__}.{metadata['module']}")
    return getattr(module, metadata["class"])


# Map every registered model name to its implementing class (imported eagerly).
MLIPMap = {
    model: _resolve_model_class(metadata) for model, metadata in REGISTRY.items()
}
# Functional Enum API: member names are the registry keys, values the classes.
MLIPEnum = Enum("MLIPEnum", MLIPMap)


def get_freer_device() -> torch.device:
    """Pick the best available compute device.

    Preference order: the CUDA GPU with the most free memory, then Apple MPS,
    then CPU. The selection is reported with ``print``. This function never
    raises for a missing accelerator; it falls back to CPU instead.

    Free memory per GPU is estimated as
    ``get_device_properties(i).total_memory - memory_allocated(i)``. Note that
    ``torch.cuda.memory_allocated`` only counts tensor memory allocated by the
    current process, so memory used by other processes on the same GPU is not
    taken into account.

    Returns:
        torch.device: ``cuda:<index>``, ``mps``, or ``cpu``.
    """
    device_count = torch.cuda.device_count()
    if device_count > 0:
        # CUDA GPUs are available: select the one with the most free memory.
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        # max() with a key keeps the first index among ties, like list.index(max(...)).
        free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
        device = torch.device(f"cuda:{free_gpu_index}")
        print(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
        return device

    if torch.backends.mps.is_available():
        # No CUDA GPUs, but Apple's Metal backend is usable.
        print("No GPU available. Using MPS.")
        return torch.device("mps")

    # Neither CUDA nor MPS: fall back to CPU rather than raising.
    print("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")


# class EXTMLIPEnum(Enum):
#     """Enumeration class for EXTMLIP models.

#     Attributes:
#         M3GNet (str): M3GNet model.
#         CHGNet (str): CHGNet model.
#         MACE (str): MACE model.
#     """

#     M3GNet = "M3GNet"
#     CHGNet = "CHGNet"
#     MACE = "MACE"
#     Equiformer = "Equiformer"


# def get_freer_device() -> torch.device:
#     """Get the GPU with the most free memory.

#     Returns:
#         torch.device: The selected GPU device.

#     Raises:
#         ValueError: If no GPU is available.
#     """
#     device_count = torch.cuda.device_count()
#     if device_count == 0:
#         print("No GPU available. Using CPU.")
#         return torch.device("cpu")

#     mem_free = [
#         torch.cuda.get_device_properties(i).total_memory
#         - torch.cuda.memory_allocated(i)
#         for i in range(device_count)
#     ]

#     free_gpu_index = mem_free.index(max(mem_free))

#     print(
#         f"Selected GPU {free_gpu_index} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs",
#     )

#     return torch.device(f"cuda:{free_gpu_index}")



# def external_ase_calculator(name: EXTMLIPEnum, **kwargs: Any) -> Calculator:
#     """Construct an ASE calculator from an external third-party MLIP package."""
#     calculator = None
#     device = get_freer_device()

#     if name == EXTMLIPEnum.MACE:
#         from mace.calculators import mace_mp

#         calculator = mace_mp(device=str(device), **kwargs)

#     elif name == EXTMLIPEnum.CHGNet:
#         from chgnet.model.dynamics import CHGNetCalculator

#         calculator = CHGNetCalculator(use_device=str(device), **kwargs)

#     elif name == EXTMLIPEnum.M3GNet:
#         import matgl
#         from matgl.ext.ase import PESCalculator

#         potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
#         calculator = PESCalculator(potential, **kwargs)



#     calculator.__setattr__("name", name.value)

#     return calculator