cyrusyc commited on
Commit
8a5a7ef
1 Parent(s): e6cac5c

add m3gnet diatomic task

Browse files
mlip_arena/models/externals.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import urllib
3
  from typing import Literal
4
 
 
 
5
  import torch
6
  from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff
7
  from ase import Atoms
@@ -9,6 +11,8 @@ from chgnet.model.dynamics import CHGNetCalculator
9
  from chgnet.model.model import CHGNet
10
  from fairchem.core import OCPCalculator
11
  from mace.calculators import MACECalculator
 
 
12
 
13
 
14
  # Avoid circular import
@@ -74,6 +78,7 @@ class MACE_MP_Medium(MACECalculator):
74
  model_paths=model, device=device, default_dtype=default_dtype, **kwargs
75
  )
76
 
 
77
  class MACE_OFF_Medium(MACECalculator):
78
  def __init__(self, device=None, default_dtype="float32", **kwargs):
79
  checkpoint_url = "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true"
@@ -133,6 +138,17 @@ class CHGNet(CHGNetCalculator):
133
  self.results.pop("crystal_fea", None)
134
 
135
 
 
 
 
 
 
 
 
 
 
 
 
136
  class EquiformerV2(OCPCalculator):
137
  def __init__(
138
  self,
@@ -191,3 +207,19 @@ class ALIGNN(AlignnAtomwiseCalculator):
191
 
192
  def calculate(self, atoms, properties=None, system_changes=None):
193
  super().calculate(atoms, properties, system_changes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import urllib
3
  from typing import Literal
4
 
5
+ import matgl
6
+ import requests
7
  import torch
8
  from alignn.ff.ff import AlignnAtomwiseCalculator, get_figshare_model_ff
9
  from ase import Atoms
 
11
  from chgnet.model.model import CHGNet
12
  from fairchem.core import OCPCalculator
13
  from mace.calculators import MACECalculator
14
+ from matgl.ext.ase import PESCalculator
15
+ from sevenn.sevennet_calculator import SevenNetCalculator
16
 
17
 
18
  # Avoid circular import
 
78
  model_paths=model, device=device, default_dtype=default_dtype, **kwargs
79
  )
80
 
81
+
82
  class MACE_OFF_Medium(MACECalculator):
83
  def __init__(self, device=None, default_dtype="float32", **kwargs):
84
  checkpoint_url = "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true"
 
138
  self.results.pop("crystal_fea", None)
139
 
140
 
141
+ class M3GNet(PESCalculator):
142
+ def __init__(
143
+ self,
144
+ state_attr: torch.Tensor | None = None,
145
+ stress_weight: float = 1.0,
146
+ **kwargs,
147
+ ) -> None:
148
+ potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
149
+ super().__init__(potential, state_attr, stress_weight, **kwargs)
150
+
151
+
152
  class EquiformerV2(OCPCalculator):
153
  def __init__(
154
  self,
 
207
 
208
  def calculate(self, atoms, properties=None, system_changes=None):
209
  super().calculate(atoms, properties, system_changes)
210
+
211
+
212
+ class SevenNet(SevenNetCalculator):
213
+ def __init__(self, device=None, **kwargs):
214
+ url = (
215
+ "https://github.com/MDIL-SNU/SevenNet/raw/main/pretrained_potentials"
216
+ "/SevenNet_0__11July2024/checkpoint_sevennet_0.pth"
217
+ )
218
+ ckpt_cache = "/tmp/sevennet_checkpoint.pth.tar"
219
+ response = requests.get(url, timeout=20)
220
+ with open(ckpt_cache, mode="wb") as file:
221
+ file.write(response.content)
222
+
223
+ device = device or get_freer_device()
224
+
225
+ super().__init__(ckpt_cache, device=device, **kwargs)
mlip_arena/models/registry.yaml CHANGED
@@ -38,6 +38,21 @@ CHGNet:
38
  nvt: true
39
  npt: true
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  EquiformerV2(OC22):
42
  module: externals
43
  class: EquiformerV2
@@ -108,4 +123,25 @@ ALIGNN:
108
  - homonuclear-diatomics
109
  prediction: EFS
110
  nvt: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  npt: true
 
38
  nvt: true
39
  npt: true
40
 
41
+ M3GNet:
42
+ module: externals
43
+ class: M3GNet
44
+ family: m3gnet
45
+ username: cyrusyc
46
+ last-update: 2024-07-08T00:00:00
47
+ datetime: 2024-07-08T00:00:00
48
+ datasets:
49
+ - atomind/mptrj
50
+ gpu-tasks:
51
+ - homonuclear-diatomics
52
+ prediction: EFS
53
+ nvt: true
54
+ npt: true
55
+
56
  EquiformerV2(OC22):
57
  module: externals
58
  class: EquiformerV2
 
123
  - homonuclear-diatomics
124
  prediction: EFS
125
  nvt: true
126
+ npt: true
127
+
128
+ SevenNet:
129
+ module: externals
130
+ class: SevenNet
131
+ family: sevennet
132
+ username: cyrusyc
133
+ last-update: 2024-03-25T14:30:00
134
+ datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
135
+ datasets:
136
+ - atomind/mptrj # TODO: fake HF dataset repo
137
+ cpu-tasks:
138
+ - alexandria
139
+ - qmof
140
+ gpu-tasks:
141
+ - homonuclear-diatomics
142
+ github: https://github.com/ACEsuit/mace
143
+ doi: https://arxiv.org/abs/2401.00096
144
+ date: 2023-12-29
145
+ prediction: EFS
146
+ nvt: true
147
  npt: true
mlip_arena/tasks/diatomics/m3gnet/homonuclear-diatomics.json ADDED
The diff for this file is too large to render. See raw diff
 
mlip_arena/tasks/diatomics/m3gnet/run.ipynb DELETED
@@ -1,295 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 6,
6
- "id": "3200850a-b8fb-4f50-9815-16ae8da0f942",
7
- "metadata": {
8
- "tags": []
9
- },
10
- "outputs": [
11
- {
12
- "name": "stdin",
13
- "output_type": "stream",
14
- "text": [
15
- "Do you really want to delete everything in /global/homes/c/cyrusyc/.cache/matgl (y|n)? y\n"
16
- ]
17
- },
18
- {
19
- "ename": "ValueError",
20
- "evalue": "Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c \"import matgl; matgl.clear_cache()\"`",
21
- "output_type": "error",
22
- "traceback": [
23
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
24
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
25
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:212\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(path, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m cls_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(mod, classname)\n\u001b[0;32m--> 212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcls_\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpaths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
26
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:129\u001b[0m, in \u001b[0;36mIOMixIn.load\u001b[0;34m(cls, path, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m d \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m d\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m k\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m@\u001b[39m\u001b[38;5;124m\"\u001b[39m)}\n\u001b[0;32m--> 129\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 130\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(state, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n",
27
- "\u001b[0;31mTypeError\u001b[0m: Potential.__init__() got an unexpected keyword argument 'calc_magmom'",
28
- "\nThe above exception was the direct cause of the following exception:\n",
29
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
30
- "Cell \u001b[0;32mIn[6], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 17\u001b[0m matgl\u001b[38;5;241m.\u001b[39mclear_cache()\n\u001b[0;32m---> 18\u001b[0m potential \u001b[38;5;241m=\u001b[39m \u001b[43mmatgl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mM3GNet-MP-2021.2.8-PES\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 19\u001b[0m calculator \u001b[38;5;241m=\u001b[39m PESCalculator(potential)\n",
31
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:214\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(path, **kwargs)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cls_\u001b[38;5;241m.\u001b[39mload(fpaths, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 214\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 215\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBad serialized model or bad model name. It is possible that you have an older model cached. Please \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclear your cache by running `python -c \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimport matgl; matgl.clear_cache()\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 217\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n",
32
- "\u001b[0;31mValueError\u001b[0m: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c \"import matgl; matgl.clear_cache()\"`"
33
- ]
34
- }
35
- ],
36
- "source": [
37
- "import os\n",
38
- "import numpy as np\n",
39
- "\n",
40
- "from ase import Atoms, Atom\n",
41
- "from ase.io import read, write\n",
42
- "from ase.data import chemical_symbols, covalent_radii, vdw_alvarez\n",
43
- "from ase.parallel import paropen as open\n",
44
- "\n",
45
- "from pathlib import Path\n",
46
- "from pymatgen.core import Element\n",
47
- "import pandas as pd\n",
48
- "\n",
49
- "from tqdm.auto import tqdm\n",
50
- "\n",
51
- "import matgl\n",
52
- "from matgl.ext.ase import PESCalculator\n",
53
- "\n",
54
- "matgl.clear_cache()\n",
55
- "potential = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
56
- "calculator = PESCalculator(potential)\n"
57
- ]
58
- },
59
- {
60
- "cell_type": "code",
61
- "execution_count": 2,
62
- "id": "90887faa-1601-4c4c-9c44-d16731471d7f",
63
- "metadata": {
64
- "scrolled": true,
65
- "tags": []
66
- },
67
- "outputs": [
68
- {
69
- "ename": "ValueError",
70
- "evalue": "Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c \"import matgl; matgl.clear_cache()\"`",
71
- "output_type": "error",
72
- "traceback": [
73
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
74
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
75
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:212\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(path, **kwargs)\u001b[0m\n\u001b[1;32m 211\u001b[0m cls_ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(mod, classname)\n\u001b[0;32m--> 212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcls_\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfpaths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
76
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:129\u001b[0m, in \u001b[0;36mIOMixIn.load\u001b[0;34m(cls, path, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m d \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m d\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m k\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m@\u001b[39m\u001b[38;5;124m\"\u001b[39m)}\n\u001b[0;32m--> 129\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 130\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(state, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n",
77
- "\u001b[0;31mTypeError\u001b[0m: Potential.__init__() got an unexpected keyword argument 'calc_magmom'",
78
- "\nThe above exception was the direct cause of the following exception:\n",
79
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
80
- "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m potential \u001b[38;5;241m=\u001b[39m \u001b[43mmatgl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mM3GNet-MP-2021.2.8-PES\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m calculator \u001b[38;5;241m=\u001b[39m PESCalculator(potential)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m symbol \u001b[38;5;129;01min\u001b[39;00m tqdm(chemical_symbols):\n",
81
- "File \u001b[0;32m/pscratch/sd/c/cyrusyc/.conda/mlip-arena/lib/python3.11/site-packages/matgl/utils/io.py:214\u001b[0m, in \u001b[0;36mload_model\u001b[0;34m(path, **kwargs)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cls_\u001b[38;5;241m.\u001b[39mload(fpaths, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 214\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 215\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBad serialized model or bad model name. It is possible that you have an older model cached. Please \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mclear your cache by running `python -c \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimport matgl; matgl.clear_cache()\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 217\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n",
82
- "\u001b[0;31mValueError\u001b[0m: Bad serialized model or bad model name. It is possible that you have an older model cached. Please clear your cache by running `python -c \"import matgl; matgl.clear_cache()\"`"
83
- ]
84
- }
85
- ],
86
- "source": [
87
- "\n",
88
- "\n",
89
- "for symbol in tqdm(chemical_symbols):\n",
90
- " \n",
91
- " s = set([symbol])\n",
92
- " \n",
93
- " if 'X' in s:\n",
94
- " continue\n",
95
- " \n",
96
- " try:\n",
97
- " atom = Atom(symbol)\n",
98
- " rmin = covalent_radii[atom.number] * 0.95\n",
99
- " rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan \n",
100
- " rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6\n",
101
- " rstep = 0.01 #if rmin < 1 else 0.4\n",
102
- "\n",
103
- " a = 2 * rmax\n",
104
- "\n",
105
- " npts = int((rmax - rmin)/rstep)\n",
106
- "\n",
107
- " rs = np.linspace(rmin, rmax, npts)\n",
108
- " e = np.zeros_like(rs)\n",
109
- "\n",
110
- " da = symbol + symbol\n",
111
- "\n",
112
- " out_dir = Path(str(da))\n",
113
- "\n",
114
- " os.makedirs(out_dir, exist_ok=True)\n",
115
- "\n",
116
- " skip = 0\n",
117
- " \n",
118
- " element = Element(symbol)\n",
119
- " \n",
120
- " try:\n",
121
- " m = element.valence[1]\n",
122
- " if element.valence == (0, 2):\n",
123
- " m = 0\n",
124
- " except:\n",
125
- " m = 0\n",
126
- " \n",
127
- " \n",
128
- " r = rs[0]\n",
129
- " \n",
130
- " positions = [\n",
131
- " [a/2-r/2, a/2, a/2],\n",
132
- " [a/2+r/2, a/2, a/2],\n",
133
- " ]\n",
134
- " \n",
135
- " traj_fpath = out_dir / \"traj.extxyz\"\n",
136
- "\n",
137
- " if traj_fpath.exists():\n",
138
- " traj = read(traj_fpath, index=\":\")\n",
139
- " skip = len(traj)\n",
140
- " atoms = traj[-1]\n",
141
- " else:\n",
142
- " # Create the unit cell with two atoms\n",
143
- " atoms = Atoms(\n",
144
- " da, \n",
145
- " positions=positions,\n",
146
- " # magmoms=magmoms,\n",
147
- " cell=[a, a+0.001, a+0.002], \n",
148
- " pbc=True\n",
149
- " )\n",
150
- " \n",
151
- " print(atoms)\n",
152
- "\n",
153
- " calc = calculator\n",
154
- "\n",
155
- " atoms.calc = calc\n",
156
- " \n",
157
- " # cdft = CDFT(calc=calc, atoms=atoms, spinspin_regions= \n",
158
- " # atoms.calc = cdft\n",
159
- "\n",
160
- " for i, r in enumerate(tqdm(np.flip(rs))):\n",
161
- "\n",
162
- " if i < skip:\n",
163
- " continue\n",
164
- "\n",
165
- " positions = [\n",
166
- " [a/2-r/2, a/2, a/2],\n",
167
- " [a/2+r/2, a/2, a/2],\n",
168
- " ]\n",
169
- " \n",
170
- " # atoms.set_initial_magnetic_moments(magmoms)\n",
171
- " \n",
172
- " atoms.set_positions(positions)\n",
173
- "\n",
174
- " e[i] = atoms.get_potential_energy()\n",
175
- " \n",
176
- " atoms.calc.results.update({\n",
177
- " \"forces\": atoms.get_forces()\n",
178
- " })\n",
179
- "\n",
180
- " write(traj_fpath, atoms, append=\"a\")\n",
181
- " except Exception as e:\n",
182
- " print(e)\n"
183
- ]
184
- },
185
- {
186
- "cell_type": "code",
187
- "execution_count": 2,
188
- "id": "a0ac2c09-370b-4fdd-bf74-ea5c4ade0215",
189
- "metadata": {},
190
- "outputs": [
191
- {
192
- "data": {
193
- "application/vnd.jupyter.widget-view+json": {
194
- "model_id": "cc766db4ce844c40848791e14a71832c",
195
- "version_major": 2,
196
- "version_minor": 0
197
- },
198
- "text/plain": [
199
- " 0%| | 0/119 [00:00<?, ?it/s]"
200
- ]
201
- },
202
- "metadata": {},
203
- "output_type": "display_data"
204
- }
205
- ],
206
- "source": [
207
- "\n",
208
- "\n",
209
- "df = pd.DataFrame(columns=['name', 'method', 'R', 'E', 'F', 'S^2'])\n",
210
- "\n",
211
- "for symbol in tqdm(chemical_symbols):\n",
212
- " \n",
213
- " da = symbol + symbol\n",
214
- " \n",
215
- " out_dir = Path(da)\n",
216
- " \n",
217
- " traj_fpath = out_dir / \"traj.extxyz\"\n",
218
- "\n",
219
- " if traj_fpath.exists():\n",
220
- " traj = read(traj_fpath, index=\":\")\n",
221
- " else:\n",
222
- " continue\n",
223
- " \n",
224
- " Rs, Es, Fs, S2s = [], [], [], []\n",
225
- " for atoms in traj:\n",
226
- " \n",
227
- " vec = atoms.positions[1] - atoms.positions[0]\n",
228
- " r = np.linalg.norm(vec)\n",
229
- " e = atoms.get_potential_energy()\n",
230
- " f = np.inner(vec/r, atoms.get_forces()[1])\n",
231
- " # s2 = np.mean(np.power(atoms.get_magnetic_moments(), 2))\n",
232
- " \n",
233
- " Rs.append(r)\n",
234
- " Es.append(e)\n",
235
- " Fs.append(f)\n",
236
- " # S2s.append(s2)\n",
237
- " \n",
238
- " data = {\n",
239
- " 'name': da,\n",
240
- " 'method': 'M3GNet',\n",
241
- " 'R': Rs,\n",
242
- " 'E': Es,\n",
243
- " 'F': Fs,\n",
244
- " 'S^2': S2s\n",
245
- " }\n",
246
- "\n",
247
- " df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)\n",
248
- "\n",
249
- "json_fpath = 'homonuclear-diatomics.json'\n",
250
- "\n",
251
- "df.to_json(json_fpath, orient='records') "
252
- ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": null,
257
- "id": "2207f50e-63a1-4199-b2e1-a11858af5108",
258
- "metadata": {
259
- "tags": []
260
- },
261
- "outputs": [],
262
- "source": [
263
- "df"
264
- ]
265
- }
266
- ],
267
- "metadata": {
268
- "kernelspec": {
269
- "display_name": "mlip-arena",
270
- "language": "python",
271
- "name": "mlip-arena"
272
- },
273
- "language_info": {
274
- "codemirror_mode": {
275
- "name": "ipython",
276
- "version": 3
277
- },
278
- "file_extension": ".py",
279
- "mimetype": "text/x-python",
280
- "name": "python",
281
- "nbconvert_exporter": "python",
282
- "pygments_lexer": "ipython3",
283
- "version": "3.11.8"
284
- },
285
- "widgets": {
286
- "application/vnd.jupyter.widget-state+json": {
287
- "state": {},
288
- "version_major": 2,
289
- "version_minor": 0
290
- }
291
- }
292
- },
293
- "nbformat": 4,
294
- "nbformat_minor": 5
295
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mlip_arena/tasks/diatomics/mace-mp/homonuclear-diatomics.json CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -5,7 +5,7 @@ scipy
5
  ase==3.23.0
6
  torch==2.2.1
7
  pymatgen==2024.4.13
8
- bokeh==2.4.3
9
  statsmodels==0.14.2
10
  # py3Dmol==2.0.0.post2
11
  # stmol==0.0.9
 
5
  ase==3.23.0
6
  torch==2.2.1
7
  pymatgen==2024.4.13
8
+ bokeh
9
  statsmodels==0.14.2
10
  # py3Dmol==2.0.0.post2
11
  # stmol==0.0.9
serve/tasks/homonuclear-diatomics.py CHANGED
@@ -23,7 +23,7 @@ The potential energy curves of homonuclear diatomics are the most fundamental in
23
  st.markdown("### Methods")
24
  container = st.container(border=True)
25
  valid_models = [model for model, metadata in REGISTRY.items() if Path(__file__).stem in metadata.get("gpu-tasks", [])]
26
- methods = container.multiselect("MLIPs", valid_models, ["MACE-MP(M)", "EquiformerV2(OC22)", "CHGNet", "eSCN(OC20)", "ALIGNN"])
27
  dft_methods = container.multiselect("DFT Methods", ["GPAW"], [])
28
 
29
  st.markdown("### Settings")
@@ -61,6 +61,9 @@ df.drop_duplicates(inplace=True, subset=["name", "method"])
61
 
62
  method_color_mapping = {method: color_sequence[i % len(color_sequence)] for i, method in enumerate(df["method"].unique())}
63
 
 
 
 
64
  for i, symbol in enumerate(chemical_symbols[1:]):
65
 
66
  if i % ncols == 0:
@@ -162,7 +165,7 @@ for i, symbol in enumerate(chemical_symbols[1:]):
162
  yaxis=dict(
163
  title=dict(text="Energy [eV]"),
164
  side="left",
165
- range=[elo, 2*(abs(elo))],
166
  )
167
  )
168
 
@@ -172,10 +175,11 @@ for i, symbol in enumerate(chemical_symbols[1:]):
172
  yaxis2=dict(
173
  title=dict(text="Force [eV/Å]"),
174
  side="right",
175
- range=[flo, 1.5*abs(flo)],
176
  overlaying="y",
177
  tickmode="sync",
178
  ),
179
  )
180
 
181
  cols[i % ncols].plotly_chart(fig, use_container_width=True)
 
 
23
  st.markdown("### Methods")
24
  container = st.container(border=True)
25
  valid_models = [model for model, metadata in REGISTRY.items() if Path(__file__).stem in metadata.get("gpu-tasks", [])]
26
+ methods = container.multiselect("MLIPs", valid_models, ["EquiformerV2(OC22)", "eSCN(OC20)", "CHGNet", "M3GNet", "MACE-MP(M)"])
27
  dft_methods = container.multiselect("DFT Methods", ["GPAW"], [])
28
 
29
  st.markdown("### Settings")
 
61
 
62
  method_color_mapping = {method: color_sequence[i % len(color_sequence)] for i, method in enumerate(df["method"].unique())}
63
 
64
+ # img_dir = Path('./images')
65
+ # img_dir.mkdir(exist_ok=True)
66
+
67
  for i, symbol in enumerate(chemical_symbols[1:]):
68
 
69
  if i % ncols == 0:
 
165
  yaxis=dict(
166
  title=dict(text="Energy [eV]"),
167
  side="left",
168
+ range=[elo, 1.5*(abs(elo))],
169
  )
170
  )
171
 
 
175
  yaxis2=dict(
176
  title=dict(text="Force [eV/Å]"),
177
  side="right",
178
+ range=[flo, 1.0*abs(flo)],
179
  overlaying="y",
180
  tickmode="sync",
181
  ),
182
  )
183
 
184
  cols[i % ncols].plotly_chart(fig, use_container_width=True)
185
+ # fig.write_image(format='svg', file=img_dir / f"{name}.svg")