cyrusyc commited on
Commit
5cb0b67
1 Parent(s): 49d0cfc

update local change

Browse files
Files changed (1) hide show
  1. mlip_arena/models/mace.py +9 -7
mlip_arena/models/mace.py CHANGED
@@ -4,16 +4,16 @@ from ase.calculators.calculator import all_changes
4
  from huggingface_hub import hf_hub_download
5
  from torch_geometric.data import Data
6
 
7
- from mlip_arena.models import MLIP
8
 
9
 
10
- class MACE_MP_Medium(MLIP):
11
- def __init__(self, device: torch.device = None):
12
  fpath = hf_hub_download(
13
- repo_id="cyrusyc/mace-universal",
14
- subfolder="pretrained",
15
  filename="2023-12-12-mace-128-L1_epoch-199.model",
16
- revision=None # TODO: Add revision
17
  )
18
  super().__init__(model_path=fpath, device=device)
19
 
@@ -25,7 +25,9 @@ class MACE_MP_Medium(MLIP):
25
  "stress",
26
  ]
27
 
28
- def calculate(self, atoms: Atoms, properties: list[str], system_changes: dict = all_changes):
 
 
29
  """Calculate energies and forces for the given Atoms object"""
30
  super().calculate(atoms, properties, system_changes)
31
 
 
4
  from huggingface_hub import hf_hub_download
5
  from torch_geometric.data import Data
6
 
7
+ from mlip_arena.models import MLIPCalculator
8
 
9
 
10
+ class MACE_MP_Medium(MLIPCalculator):
11
+ def __init__(self, device: torch.device | None = None):
12
  fpath = hf_hub_download(
13
+ repo_id="cyrusyc/mace-universal",
14
+ subfolder="pretrained",
15
  filename="2023-12-12-mace-128-L1_epoch-199.model",
16
+ revision=None, # TODO: Add revision
17
  )
18
  super().__init__(model_path=fpath, device=device)
19
 
 
25
  "stress",
26
  ]
27
 
28
+ def calculate(
29
+ self, atoms: Atoms, properties: list[str], system_changes: list = all_changes
30
+ ):
31
  """Calculate energies and forces for the given Atoms object"""
32
  super().calculate(atoms, properties, system_changes)
33