cyrusyc commited on
Commit
dd24ea1
1 Parent(s): e59afe6

clearn up unused md; force to use the updated one

Browse files
mlip_arena/tasks/md.py CHANGED
@@ -354,7 +354,7 @@ def run(
354
  dyn.set_temperature(temperature_K=t_schedule[step])
355
  if ensemble == "nvt":
356
  return
357
- dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
358
  pbar.update()
359
 
360
  md_runner.attach(_callback, interval=1)
 
354
  dyn.set_temperature(temperature_K=t_schedule[step])
355
  if ensemble == "nvt":
356
  return
357
+ dyn.set_stress(p_schedule[step])
358
  pbar.update()
359
 
360
  md_runner.attach(_callback, interval=1)
mlip_arena/tasks/stability/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
-
2
- from .run import md as MD
3
-
 
 
 
 
mlip_arena/tasks/stability/run.py DELETED
@@ -1,301 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from datetime import datetime, timedelta
4
- from pathlib import Path
5
- from typing import Literal, Sequence, Tuple
6
-
7
- import numpy as np
8
- from ase import Atoms, units
9
- from ase.calculators.calculator import Calculator
10
- from ase.calculators.mixing import SumCalculator
11
- from ase.io import read
12
- from ase.io.trajectory import Trajectory
13
- from ase.md.andersen import Andersen
14
- from ase.md.langevin import Langevin
15
- from ase.md.md import MolecularDynamics
16
- from ase.md.npt import NPT
17
- from ase.md.nptberendsen import NPTBerendsen
18
- from ase.md.nvtberendsen import NVTBerendsen
19
- from ase.md.velocitydistribution import (
20
- MaxwellBoltzmannDistribution,
21
- Stationary,
22
- ZeroRotation,
23
- )
24
- from ase.md.verlet import VelocityVerlet
25
- from prefect import task
26
- from prefect.tasks import task_input_hash
27
- from scipy.interpolate import interp1d
28
- from scipy.linalg import schur
29
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
30
- from tqdm.auto import tqdm
31
-
32
- from mlip_arena.models.utils import MLIPEnum, get_freer_device
33
-
34
- # from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
35
-
36
- _valid_dynamics: dict[str, tuple[str, ...]] = {
37
- "nve": ("velocityverlet",),
38
- "nvt": ("nose-hoover", "langevin", "andersen", "berendsen"),
39
- "npt": ("nose-hoover", "berendsen"),
40
- }
41
-
42
- _preset_dynamics: dict = {
43
- "nve_velocityverlet": VelocityVerlet,
44
- "nvt_andersen": Andersen,
45
- "nvt_berendsen": NVTBerendsen,
46
- "nvt_langevin": Langevin,
47
- "nvt_nose-hoover": NPT,
48
- "npt_berendsen": NPTBerendsen,
49
- "npt_nose-hoover": NPT,
50
- }
51
-
52
- def _interpolate_quantity(values: Sequence | np.ndarray, n_pts: int) -> np.ndarray:
53
- """Interpolate temperature / pressure on a schedule."""
54
- n_vals = len(values)
55
- return np.interp(
56
- np.linspace(0, n_vals - 1, n_pts + 1),
57
- np.linspace(0, n_vals - 1, n_vals),
58
- values,
59
- )
60
-
61
- def _get_ensemble_schedule(
62
- ensemble: Literal["nve", "nvt", "npt"] = "nvt",
63
- n_steps: int = 1000,
64
- temperature: float | Sequence | np.ndarray | None = 300.0,
65
- pressure: float | Sequence | np.ndarray | None = None
66
- ) -> Tuple[np.ndarray, np.ndarray]:
67
- if ensemble == "nve":
68
- # Disable thermostat and barostat
69
- temperature = np.nan
70
- pressure = np.nan
71
- t_schedule = np.full(n_steps + 1, temperature)
72
- p_schedule = np.full(n_steps + 1, pressure)
73
- return t_schedule, p_schedule
74
-
75
- if isinstance(temperature, Sequence) or (
76
- isinstance(temperature, np.ndarray) and temperature.ndim == 1
77
- ):
78
- t_schedule = _interpolate_quantity(temperature, n_steps)
79
- # NOTE: In ASE Langevin dynamics, the temperature are normally
80
- # scalars, but in principle one quantity per atom could be specified by giving
81
- # an array. This is not implemented yet here.
82
- else:
83
- t_schedule = np.full(n_steps + 1, temperature)
84
-
85
- if ensemble == "nvt":
86
- pressure = np.nan
87
- p_schedule = np.full(n_steps + 1, pressure)
88
- return t_schedule, p_schedule
89
-
90
- if isinstance(pressure, Sequence) or (
91
- isinstance(pressure, np.ndarray) and pressure.ndim == 1
92
- ):
93
- p_schedule = _interpolate_quantity(pressure, n_steps)
94
- elif isinstance(pressure, np.ndarray) and pressure.ndim == 4:
95
- p_schedule = interp1d(
96
- np.arange(n_steps + 1), pressure, kind="linear"
97
- )
98
- assert isinstance(p_schedule, np.ndarray)
99
- else:
100
- p_schedule = np.full(n_steps + 1, pressure)
101
-
102
- return t_schedule, p_schedule
103
-
104
- def _get_ensemble_defaults(
105
- ensemble: Literal["nve", "nvt", "npt"],
106
- dynamics: str | MolecularDynamics,
107
- t_schedule: np.ndarray,
108
- p_schedule: np.ndarray,
109
- ase_md_kwargs: dict | None = None) -> dict:
110
- """Update ASE MD kwargs"""
111
- ase_md_kwargs = ase_md_kwargs or {}
112
-
113
- if ensemble == "nve":
114
- ase_md_kwargs.pop("temperature", None)
115
- ase_md_kwargs.pop("temperature_K", None)
116
- ase_md_kwargs.pop("externalstress", None)
117
- elif ensemble == "nvt":
118
- ase_md_kwargs["temperature_K"] = t_schedule[0]
119
- ase_md_kwargs.pop("externalstress", None)
120
- elif ensemble == "npt":
121
- ase_md_kwargs["temperature_K"] = t_schedule[0]
122
- ase_md_kwargs["externalstress"] = p_schedule[0] # * 1e3 * units.bar
123
-
124
- if isinstance(dynamics, str) and dynamics.lower() == "langevin":
125
- ase_md_kwargs["friction"] = ase_md_kwargs.get(
126
- "friction",
127
- 10.0 * 1e-3 / units.fs, # Same default as in VASP: 10 ps^-1
128
- )
129
-
130
- return ase_md_kwargs
131
-
132
-
133
-
134
-
135
- @task(cache_key_fn=task_input_hash, cache_expiration=timedelta(days=1))
136
- def md(
137
- atoms: Atoms,
138
- calculator_name: str | MLIPEnum,
139
- calculator_kwargs: dict | None,
140
- dispersion: str | None = None,
141
- dispersion_kwargs: dict | None = None,
142
- device: str | None = None,
143
- ensemble: Literal["nve", "nvt", "npt"] = "nvt",
144
- dynamics: str | MolecularDynamics = "langevin",
145
- time_step: float | None = None,
146
- total_time: float = 1000,
147
- temperature: float | Sequence | np.ndarray | None = 300.0,
148
- pressure: float | Sequence | np.ndarray | None = None,
149
- ase_md_kwargs: dict | None = None,
150
- mb_velocity_seed: int | None = None,
151
- zero_linear_momentum: bool = True,
152
- zero_angular_momentum: bool = True,
153
- traj_file: str | Path | None = None,
154
- traj_interval: int = 1,
155
- # ttime: float = 25 * units.fs,
156
- # pfactor: float = (75 * units.fs) ** 1 * units.GPa,
157
- # mask: np.ndarray | list[int] | None = None,
158
- # traceless: float = 1.0,
159
- restart: bool = True,
160
- # interval: int = 500,
161
- # device: str | None = None,
162
- # dtype: str = "float64",
163
- ):
164
- device = device or str(get_freer_device())
165
-
166
- print(f"Using device: {device}")
167
-
168
- calculator_kwargs = calculator_kwargs or {}
169
-
170
- if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
171
- assert issubclass(calculator_name.value, Calculator)
172
- calc = calculator_name.value(**calculator_kwargs)
173
- elif isinstance(calculator_name, str) and calculator_name in MLIPEnum._member_names_:
174
- calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
175
- else:
176
- raise ValueError(f"Invalid calculator: {calculator_name}")
177
-
178
- print(f"Using calculator: {calc}")
179
-
180
- dispersion_kwargs = dispersion_kwargs or {}
181
-
182
- dispersion_kwargs.update({"device": device})
183
-
184
- if dispersion is not None:
185
- disp_calc = TorchDFTD3Calculator(
186
- **dispersion_kwargs,
187
- )
188
- calc = SumCalculator([calc, disp_calc])
189
-
190
- print(f"Using dispersion: {dispersion}")
191
-
192
- atoms.calc = calc
193
-
194
- if time_step is None:
195
- # If a structure contains an isotope of hydrogen, set default `time_step`
196
- # to 0.5 fs, and 2 fs otherwise.
197
- has_h_isotope = "H" in atoms.get_chemical_symbols()
198
- time_step = 0.5 if has_h_isotope else 2.0
199
-
200
- n_steps = int(total_time / time_step)
201
- target_steps = n_steps
202
-
203
- t_schedule, p_schedule = _get_ensemble_schedule(
204
- ensemble=ensemble,
205
- n_steps=n_steps,
206
- temperature=temperature,
207
- pressure=pressure,
208
- )
209
-
210
- ase_md_kwargs = _get_ensemble_defaults(
211
- ensemble=ensemble,
212
- dynamics=dynamics,
213
- t_schedule=t_schedule,
214
- p_schedule=p_schedule,
215
- ase_md_kwargs=ase_md_kwargs,
216
- )
217
-
218
- if isinstance(dynamics, str):
219
- # Use known dynamics if `self.dynamics` is a str
220
- dynamics = dynamics.lower()
221
- if dynamics not in _valid_dynamics[ensemble]:
222
- raise ValueError(
223
- f"{dynamics} thermostat not available for {ensemble}."
224
- f"Available {ensemble} thermostats are:"
225
- " ".join(_valid_dynamics[ensemble])
226
- )
227
-
228
- if ensemble == "nve" and dynamics is None:
229
- dynamics = "velocityverlet"
230
- md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
231
- elif issubclass(dynamics, MolecularDynamics):
232
- md_class = dynamics
233
-
234
- if md_class is NPT:
235
- # Note that until md_func is instantiated, isinstance(md_func,NPT) is False
236
- # ASE NPT implementation requires upper triangular cell
237
- u, _ = schur(atoms.get_cell(complete=True), output="complex")
238
- atoms.set_cell(u.real, scale_atoms=True)
239
-
240
- last_step = 0
241
-
242
- if traj_file is not None:
243
- traj_file = Path(traj_file)
244
-
245
- if restart and traj_file.exists():
246
- traj = read(traj_file, index=":")
247
- last_step = traj[-1].info.get("step", len(traj) * traj_interval)
248
- n_steps -= last_step
249
- last_atoms = traj[-1]
250
- traj = Trajectory(traj_file, "a", atoms)
251
- atoms.set_positions(last_atoms.get_positions())
252
- atoms.set_momenta(last_atoms.get_momenta())
253
- else:
254
- traj = Trajectory(traj_file, "w", atoms)
255
-
256
- if not np.isnan(t_schedule).any():
257
- MaxwellBoltzmannDistribution(
258
- atoms=atoms,
259
- temperature_K=t_schedule[last_step],
260
- rng=np.random.default_rng(seed=mb_velocity_seed),
261
- )
262
-
263
- if zero_linear_momentum:
264
- Stationary(atoms)
265
- if zero_angular_momentum:
266
- ZeroRotation(atoms)
267
-
268
- md_runner = md_class(
269
- atoms=atoms,
270
- timestep=time_step * units.fs,
271
- **ase_md_kwargs,
272
- )
273
-
274
- if traj_file is not None:
275
- md_runner.attach(traj.write, interval=traj_interval)
276
-
277
- with tqdm(total=n_steps) as pbar:
278
-
279
- def _callback(dyn: MolecularDynamics = md_runner) -> None:
280
- step = last_step + dyn.nsteps
281
- dyn.atoms.info["restart"] = last_step
282
- dyn.atoms.info["datetime"] = datetime.now()
283
- dyn.atoms.info["step"] = step
284
- dyn.atoms.info["target_steps"] = target_steps
285
- if ensemble == "nve":
286
- return
287
- dyn.set_temperature(temperature_K=t_schedule[step])
288
- if ensemble == "nvt":
289
- return
290
- dyn.set_stress(p_schedule[step] * 1e3 * units.bar)
291
- pbar.update()
292
-
293
- md_runner.attach(_callback, interval=1)
294
-
295
- start_time = datetime.now()
296
- md_runner.run(steps=n_steps)
297
- end_time = datetime.now()
298
-
299
- traj.close()
300
-
301
- return {"runtime": end_time - start_time, "n_steps": n_steps}