Yuan (Cyrus) Chiang
Add `eqV2_86M_omat_mp_salex` model (#14)
52c1bfb unverified
raw
history blame
1.24 kB
from __future__ import annotations
import os
from pathlib import Path
from mace.calculators import MACECalculator
from mlip_arena.models.utils import get_freer_device
class MACE_MP_Medium(MACECalculator):
    """ASE calculator wrapping the MACE-MP "medium" checkpoint.

    The checkpoint is downloaded once into ``~/.cache/mace`` and reused on
    later constructions. If no device is given, one is chosen automatically.

    Args:
        checkpoint: URL of the checkpoint to download, or a path to an
            existing local checkpoint file (used as-is, no download).
        device: Device string passed to ``MACECalculator``; when ``None``,
            ``str(get_freer_device())`` is used.
        default_dtype: Passed through to ``MACECalculator``.
        **kwargs: Forwarded unchanged to ``MACECalculator.__init__``.

    Raises:
        ValueError: If no cache filename can be derived from ``checkpoint``.
        RuntimeError: If the server answers with an HTML page instead of the
            checkpoint file.
    """

    def __init__(
        self,
        checkpoint="http://tinyurl.com/5yyxdm76",
        device: str | None = None,
        default_dtype="float32",
        **kwargs,
    ):
        model = self._fetch_checkpoint(checkpoint)
        device = device or str(get_freer_device())
        super().__init__(
            model_paths=model, device=device, default_dtype=default_dtype, **kwargs
        )

    @staticmethod
    def _fetch_checkpoint(checkpoint: str) -> str:
        """Return a local filesystem path for *checkpoint*, downloading on first use."""
        # ``import urllib`` alone does not load the ``request`` submodule; it
        # has to be imported explicitly for ``urllib.request`` to exist.
        import urllib.request

        # A checkpoint that is already a local file needs no download.
        if os.path.isfile(checkpoint):
            return str(checkpoint)

        cache_dir = Path.home() / ".cache" / "mace"
        # Build a filesystem-safe cache filename from the last URL component.
        checkpoint_url_name = "".join(
            c for c in os.path.basename(checkpoint) if c.isalnum() or c in "_"
        )
        if not checkpoint_url_name:
            raise ValueError(
                f"Cannot derive a cache filename from checkpoint {checkpoint!r}"
            )
        cached_model_path = cache_dir / checkpoint_url_name
        if cached_model_path.is_file():
            return str(cached_model_path)

        cache_dir.mkdir(parents=True, exist_ok=True)
        # Download under a temporary name and move it into place only after it
        # has been validated. A rejected or interrupted download then never
        # leaves a file that later runs would treat as a cached checkpoint.
        tmp_path = cached_model_path.with_name(cached_model_path.name + ".part")
        try:
            _, http_msg = urllib.request.urlretrieve(checkpoint, str(tmp_path))
            # ``http_msg`` is an ``email.message.Message``: ``in`` on it tests
            # header names, so the Content-Type value must be read explicitly.
            if http_msg.get_content_type() == "text/html":
                raise RuntimeError(
                    f"Model download failed, please check the URL {checkpoint}"
                )
            os.replace(tmp_path, cached_model_path)
        finally:
            # No-op after a successful os.replace; removes leftovers otherwise.
            tmp_path.unlink(missing_ok=True)
        return str(cached_model_path)