Spaces:
Running
Running
clean up unused md; force use of the updated one
Browse files
mlip_arena/tasks/md.py
CHANGED
@@ -354,7 +354,7 @@ def run(
|
|
354 |
dyn.set_temperature(temperature_K=t_schedule[step])
|
355 |
if ensemble == "nvt":
|
356 |
return
|
357 |
-
dyn.set_stress(p_schedule[step]
|
358 |
pbar.update()
|
359 |
|
360 |
md_runner.attach(_callback, interval=1)
|
|
|
354 |
dyn.set_temperature(temperature_K=t_schedule[step])
|
355 |
if ensemble == "nvt":
|
356 |
return
|
357 |
+
dyn.set_stress(p_schedule[step])
|
358 |
pbar.update()
|
359 |
|
360 |
md_runner.attach(_callback, interval=1)
|
mlip_arena/tasks/stability/__init__.py
CHANGED
@@ -1,3 +0,0 @@
|
|
1 |
-
|
2 |
-
from .run import md as MD
|
3 |
-
|
|
|
|
|
|
|
|
mlip_arena/tasks/stability/run.py
DELETED
@@ -1,301 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from datetime import datetime, timedelta
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Literal, Sequence, Tuple
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
from ase import Atoms, units
|
9 |
-
from ase.calculators.calculator import Calculator
|
10 |
-
from ase.calculators.mixing import SumCalculator
|
11 |
-
from ase.io import read
|
12 |
-
from ase.io.trajectory import Trajectory
|
13 |
-
from ase.md.andersen import Andersen
|
14 |
-
from ase.md.langevin import Langevin
|
15 |
-
from ase.md.md import MolecularDynamics
|
16 |
-
from ase.md.npt import NPT
|
17 |
-
from ase.md.nptberendsen import NPTBerendsen
|
18 |
-
from ase.md.nvtberendsen import NVTBerendsen
|
19 |
-
from ase.md.velocitydistribution import (
|
20 |
-
MaxwellBoltzmannDistribution,
|
21 |
-
Stationary,
|
22 |
-
ZeroRotation,
|
23 |
-
)
|
24 |
-
from ase.md.verlet import VelocityVerlet
|
25 |
-
from prefect import task
|
26 |
-
from prefect.tasks import task_input_hash
|
27 |
-
from scipy.interpolate import interp1d
|
28 |
-
from scipy.linalg import schur
|
29 |
-
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
|
30 |
-
from tqdm.auto import tqdm
|
31 |
-
|
32 |
-
from mlip_arena.models.utils import MLIPEnum, get_freer_device
|
33 |
-
|
34 |
-
# from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
|
35 |
-
|
36 |
-
_valid_dynamics: dict[str, tuple[str, ...]] = {
|
37 |
-
"nve": ("velocityverlet",),
|
38 |
-
"nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
|
39 |
-
"npt": ("nose-hoover", "berendsen"),
|
40 |
-
}
|
41 |
-
|
42 |
-
_preset_dynamics: dict = {
|
43 |
-
"nve_velocityverlet": VelocityVerlet,
|
44 |
-
"nvt_andersen": Andersen,
|
45 |
-
"nvt_berendsen": NVTBerendsen,
|
46 |
-
"nvt_langevin": Langevin,
|
47 |
-
"nvt_nose-hoover": NPT,
|
48 |
-
"npt_berendsen": NPTBerendsen,
|
49 |
-
"npt_nose-hoover": NPT,
|
50 |
-
}
|
51 |
-
|
52 |
-
def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
|
53 |
-
"""Interpolate temperature / pressure on a schedule."""
|
54 |
-
n_vals = len(values)
|
55 |
-
return np.interp(
|
56 |
-
np.linspace(0, n_vals - 1, n_pts + 1),
|
57 |
-
np.linspace(0, n_vals - 1, n_vals),
|
58 |
-
values,
|
59 |
-
)
|
60 |
-
|
61 |
-
def _get_ensemble_schedule(
|
62 |
-
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
63 |
-
n_steps: int = 1000,
|
64 |
-
temperature: float | Sequence | np.ndarray | None = 300.0,
|
65 |
-
pressure: float | Sequence | np.ndarray | None = None
|
66 |
-
) -> Tuple[np.ndarray, np.ndarray]:
|
67 |
-
if ensemble == "nve":
|
68 |
-
# Disable thermostat and barostat
|
69 |
-
temperature = np.nan
|
70 |
-
pressure = np.nan
|
71 |
-
t_schedule = np.full(n_steps + 1, temperature)
|
72 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
73 |
-
return t_schedule, p_schedule
|
74 |
-
|
75 |
-
if isinstance(temperature, Sequence) or (
|
76 |
-
isinstance(temperature, np.ndarray) and temperature.ndim == 1
|
77 |
-
):
|
78 |
-
t_schedule = _interpolate_quantity(temperature, n_steps)
|
79 |
-
# NOTE: In ASE Langevin dynamics, the temperature are normally
|
80 |
-
# scalars, but in principle one quantity per atom could be specified by giving
|
81 |
-
# an array. This is not implemented yet here.
|
82 |
-
else:
|
83 |
-
t_schedule = np.full(n_steps + 1, temperature)
|
84 |
-
|
85 |
-
if ensemble == "nvt":
|
86 |
-
pressure = np.nan
|
87 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
88 |
-
return t_schedule, p_schedule
|
89 |
-
|
90 |
-
if isinstance(pressure, Sequence) or (
|
91 |
-
isinstance(pressure, np.ndarray) and pressure.ndim == 1
|
92 |
-
):
|
93 |
-
p_schedule = _interpolate_quantity(pressure, n_steps)
|
94 |
-
elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
|
95 |
-
p_schedule = interp1d(
|
96 |
-
np.arange(n_steps + 1), pressure, kind="linear"
|
97 |
-
)
|
98 |
-
assert isinstance(p_schedule, np.ndarray)
|
99 |
-
else:
|
100 |
-
p_schedule = np.full(n_steps + 1, pressure)
|
101 |
-
|
102 |
-
return t_schedule, p_schedule
|
103 |
-
|
104 |
-
def _get_ensemble_defaults(
|
105 |
-
ensemble: Literal["nve", "nvt", "npt"],
|
106 |
-
dynamics: str | MolecularDynamics,
|
107 |
-
t_schedule: np.ndarray,
|
108 |
-
p_schedule: np.ndarray,
|
109 |
-
ase_md_kwargs: dict | None = None) -> dict:
|
110 |
-
"""Update ASE MD kwargs"""
|
111 |
-
ase_md_kwargs = ase_md_kwargs or {}
|
112 |
-
|
113 |
-
if ensemble == "nve":
|
114 |
-
ase_md_kwargs.pop("temperature", None)
|
115 |
-
ase_md_kwargs.pop("temperature_K", None)
|
116 |
-
ase_md_kwargs.pop("externalstress", None)
|
117 |
-
elif ensemble == "nvt":
|
118 |
-
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
119 |
-
ase_md_kwargs.pop("externalstress", None)
|
120 |
-
elif ensemble == "npt":
|
121 |
-
ase_md_kwargs["temperature_K"] = t_schedule[0]
|
122 |
-
ase_md_kwargs["externalstress"] = p_schedule[0] # * 1e3 * units.bar
|
123 |
-
|
124 |
-
if isinstance(dynamics, str) and dynamics.lower() == "langevin":
|
125 |
-
ase_md_kwargs["friction"] = ase_md_kwargs.get(
|
126 |
-
"friction",
|
127 |
-
10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
|
128 |
-
)
|
129 |
-
|
130 |
-
return ase_md_kwargs
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
@task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
|
136 |
-
def md(
|
137 |
-
atoms: Atoms,
|
138 |
-
calculator_name: str | MLIPEnum,
|
139 |
-
calculator_kwargs: dict | None,
|
140 |
-
dispersion: str | None = None,
|
141 |
-
dispersion_kwargs: dict | None = None,
|
142 |
-
device: str | None = None,
|
143 |
-
ensemble: Literal["nve", "nvt", "npt"] = "nvt",
|
144 |
-
dynamics: str | MolecularDynamics = "langevin",
|
145 |
-
time_step: float | None = None,
|
146 |
-
total_time: float = 1000,
|
147 |
-
temperature: float | Sequence | np.ndarray | None = 300.0,
|
148 |
-
pressure: float | Sequence | np.ndarray | None = None,
|
149 |
-
ase_md_kwargs: dict | None = None,
|
150 |
-
mb_velocity_seed: int | None = None,
|
151 |
-
zero_linear_momentum: bool = True,
|
152 |
-
zero_angular_momentum: bool = True,
|
153 |
-
traj_file: str | Path | None = None,
|
154 |
-
traj_interval: int = 1,
|
155 |
-
# ttime: float = 25 * units.fs,
|
156 |
-
# pfactor: float = (75 * units.fs) ** 1 * units.GPa,
|
157 |
-
# mask: np.ndarray | list[int] | None = None,
|
158 |
-
# traceless: float = 1.0,
|
159 |
-
restart: bool = True,
|
160 |
-
# interval: int = 500,
|
161 |
-
# device: str | None = None,
|
162 |
-
# dtype: str = "float64",
|
163 |
-
):
|
164 |
-
device = device or str(get_freer_device())
|
165 |
-
|
166 |
-
print(f"Using device: {device}")
|
167 |
-
|
168 |
-
calculator_kwargs = calculator_kwargs or {}
|
169 |
-
|
170 |
-
if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
|
171 |
-
assert issubclass(calculator_name.value, Calculator)
|
172 |
-
calc = calculator_name.value(**calculator_kwargs)
|
173 |
-
elif isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_:
|
174 |
-
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
|
175 |
-
else:
|
176 |
-
raise ValueError(f"Invalid calculator: {calculator_name}")
|
177 |
-
|
178 |
-
print(f"Using calculator: {calc}")
|
179 |
-
|
180 |
-
dispersion_kwargs = dispersion_kwargs or {}
|
181 |
-
|
182 |
-
dispersion_kwargs.update({"device": device})
|
183 |
-
|
184 |
-
if dispersion is not None:
|
185 |
-
disp_calc = TorchDFTD3Calculator(
|
186 |
-
**dispersion_kwargs,
|
187 |
-
)
|
188 |
-
calc = SumCalculator([calc, disp_calc])
|
189 |
-
|
190 |
-
print(f"Using dispersion: {dispersion}")
|
191 |
-
|
192 |
-
atoms.calc = calc
|
193 |
-
|
194 |
-
if time_step is None:
|
195 |
-
# If a structure contains an isotope of hydrogen, set default `time_step`
|
196 |
-
# to 0.5 fs, and 2 fs otherwise.
|
197 |
-
has_h_isotope = "H" in atoms.get_chemical_symbols()
|
198 |
-
time_step = 0.5 if has_h_isotope else 2.0
|
199 |
-
|
200 |
-
n_steps = int(total_time / time_step)
|
201 |
-
target_steps = n_steps
|
202 |
-
|
203 |
-
t_schedule, p_schedule = _get_ensemble_schedule(
|
204 |
-
ensemble=ensemble,
|
205 |
-
n_steps=n_steps,
|
206 |
-
temperature=temperature,
|
207 |
-
pressure=pressure,
|
208 |
-
)
|
209 |
-
|
210 |
-
ase_md_kwargs = _get_ensemble_defaults(
|
211 |
-
ensemble=ensemble,
|
212 |
-
dynamics=dynamics,
|
213 |
-
t_schedule=t_schedule,
|
214 |
-
p_schedule=p_schedule,
|
215 |
-
ase_md_kwargs=ase_md_kwargs,
|
216 |
-
)
|
217 |
-
|
218 |
-
if isinstance(dynamics, str):
|
219 |
-
# Use known dynamics if `self.dynamics` is a str
|
220 |
-
dynamics = dynamics.lower()
|
221 |
-
if dynamics not in _valid_dynamics[ensemble]:
|
222 |
-
raise ValueError(
|
223 |
-
f"{dynamics} thermostat not available for {ensemble}."
|
224 |
-
f"Available {ensemble} thermostats are:"
|
225 |
-
" ".join(_valid_dynamics[ensemble])
|
226 |
-
)
|
227 |
-
|
228 |
-
if ensemble == "nve" and dynamics is None:
|
229 |
-
dynamics = "velocityverlet"
|
230 |
-
md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
|
231 |
-
elif issubclass(dynamics, MolecularDynamics):
|
232 |
-
md_class = dynamics
|
233 |
-
|
234 |
-
if md_class is NPT:
|
235 |
-
# Note that until md_func is instantiated, isinstance(md_func,NPT) is False
|
236 |
-
# ASE NPT implementation requires upper triangular cell
|
237 |
-
u, _ = schur(atoms.get_cell(complete=True), output="complex")
|
238 |
-
atoms.set_cell(u.real, scale_atoms=True)
|
239 |
-
|
240 |
-
last_step = 0
|
241 |
-
|
242 |
-
if traj_file is not None:
|
243 |
-
traj_file = Path(traj_file)
|
244 |
-
|
245 |
-
if restart and traj_file.exists():
|
246 |
-
traj = read(traj_file, index=":")
|
247 |
-
last_step = traj[-1].info.get("step", len(traj) * traj_interval)
|
248 |
-
n_steps -= last_step
|
249 |
-
last_atoms = traj[-1]
|
250 |
-
traj = Trajectory(traj_file, "a", atoms)
|
251 |
-
atoms.set_positions(last_atoms.get_positions())
|
252 |
-
atoms.set_momenta(last_atoms.get_momenta())
|
253 |
-
else:
|
254 |
-
traj = Trajectory(traj_file, "w", atoms)
|
255 |
-
|
256 |
-
if not np.isnan(t_schedule).any():
|
257 |
-
MaxwellBoltzmannDistribution(
|
258 |
-
atoms=atoms,
|
259 |
-
temperature_K=t_schedule[last_step],
|
260 |
-
rng=np.random.default_rng(seed=mb_velocity_seed),
|
261 |
-
)
|
262 |
-
|
263 |
-
if zero_linear_momentum:
|
264 |
-
Stationary(atoms)
|
265 |
-
if zero_angular_momentum:
|
266 |
-
ZeroRotation(atoms)
|
267 |
-
|
268 |
-
md_runner = md_class(
|
269 |
-
atoms=atoms,
|
270 |
-
timestep=time_step * units.fs,
|
271 |
-
**ase_md_kwargs,
|
272 |
-
)
|
273 |
-
|
274 |
-
if traj_file is not None:
|
275 |
-
md_runner.attach(traj.write, interval=traj_interval)
|
276 |
-
|
277 |
-
with tqdm(total=n_steps) as pbar:
|
278 |
-
|
279 |
-
def _callback(dyn: MolecularDynamics = md_runner) -> None:
|
280 |
-
step = last_step + dyn.nsteps
|
281 |
-
dyn.atoms.info["restart"] = last_step
|
282 |
-
dyn.atoms.info["datetime"] = datetime.now()
|
283 |
-
dyn.atoms.info["step"] = step
|
284 |
-
dyn.atoms.info["target_steps"] = target_steps
|
285 |
-
if ensemble == "nve":
|
286 |
-
return
|
287 |
-
dyn.set_temperature(temperature_K=t_schedule[step])
|
288 |
-
if ensemble == "nvt":
|
289 |
-
return
|
290 |
-
dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
|
291 |
-
pbar.update()
|
292 |
-
|
293 |
-
md_runner.attach(_callback, interval=1)
|
294 |
-
|
295 |
-
start_time = datetime.now()
|
296 |
-
md_runner.run(steps=n_steps)
|
297 |
-
end_time = datetime.now()
|
298 |
-
|
299 |
-
traj.close()
|
300 |
-
|
301 |
-
return {"runtime": end_time - start_time, "n_steps": n_steps}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|