cyrusyc commited on
Commit
0ffedd3
·
1 Parent(s): 9d1a2a5

add task `stability` and md run script

Browse files
Files changed (1) hide show
  1. mlip_arena/tasks/stability/run.py +188 -1
mlip_arena/tasks/stability/run.py CHANGED
@@ -1,3 +1,190 @@
 
1
 
2
- from mlip_arena.tasks.utils import _valid_dynamics, _preset_dynamics
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
 
3
+ import datetime
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Literal, Sequence
7
 
8
+ import numpy as np
9
+ import torch
10
+ from ase import Atoms, units
11
+ from ase.calculators.mixing import SumCalculator
12
+ from ase.io import read
13
+ from ase.io.trajectory import Trajectory
14
+ from ase.md.md import MolecularDynamics
15
+ from ase.md.npt import NPT
16
+ from ase.md.velocitydistribution import (
17
+ MaxwellBoltzmannDistribution,
18
+ Stationary,
19
+ ZeroRotation,
20
+ )
21
+ from scipy.linalg import schur
22
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
23
+ from tqdm.auto import tqdm
24
+
25
+ from mlip_arena.models.utils import EXTMLIPEnum, MLIPMap, external_ase_calculator
26
+ from mlip_arena.tasks.utils import (
27
+ _get_ensemble_defaults,
28
+ _get_ensemble_schedule,
29
+ _preset_dynamics,
30
+ _valid_dynamics,
31
+ )
32
+
33
+
34
+ def md(
35
+ atoms: Atoms,
36
+ calculator_name: str | EXTMLIPEnum,
37
+ calculator_kwargs: dict | None,
38
+ dispersion: str | None = None,
39
+ dispersion_kwargs: dict | None = None,
40
+ device: str | None = None,
41
+ ensemble: Literal["nve", "nvt", "npt"] = "nvt",
42
+ dynamics: str | MolecularDynamics = "langevin",
43
+ time_step: float | None = None,
44
+ total_time: float = 1000,
45
+ temperature: float | Sequence | np.ndarray | None = 300.0,
46
+ pressure: float | Sequence | np.ndarray | None = None,
47
+ ase_md_kwargs: dict | None = None,
48
+ mb_velocity_seed: int | None = None,
49
+ zero_linear_momentum: bool = True,
50
+ zero_angular_momentum: bool = True,
51
+ traj_file: str | Path | None = None,
52
+ traj_interval: int = 1,
53
+ # ttime: float = 25 * units.fs,
54
+ # pfactor: float = (75 * units.fs) ** 1 * units.GPa,
55
+ # mask: np.ndarray | list[int] | None = None,
56
+ # traceless: float = 1.0,
57
+ restart: bool = True,
58
+ # interval: int = 500,
59
+ # device: str | None = None,
60
+ # dtype: str = "float64",
61
+ ):
62
+ device = device or ("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
+ print(f"Using device: {device}")
65
+
66
+ calculator_kwargs = calculator_kwargs or {}
67
+
68
+ if isinstance(calculator_name, EXTMLIPEnum) and calculator_name in EXTMLIPEnum:
69
+ calc = external_ase_calculator(calculator_name, **calculator_kwargs)
70
+ elif calculator_name in MLIPMap:
71
+ calc = MLIPMap[calculator_name](**calculator_kwargs)
72
+
73
+ print(f"Using calculator: {calc}")
74
+
75
+ dispersion_kwargs = dispersion_kwargs or {}
76
+
77
+ dispersion_kwargs.update({"device": device})
78
+
79
+ if dispersion is not None:
80
+ disp_calc = TorchDFTD3Calculator(
81
+ **dispersion_kwargs,
82
+ )
83
+ calc = SumCalculator([calc, disp_calc])
84
+
85
+ print(f"Using dispersion: {dispersion}")
86
+
87
+ atoms.calc = calc
88
+
89
+ if time_step is None:
90
+ # If a structure contains an isotope of hydrogen, set default `time_step`
91
+ # to 0.5 fs, and 2 fs otherwise.
92
+ has_h_isotope = "H" in atoms.get_chemical_symbols()
93
+ time_step = 0.5 if has_h_isotope else 2.0
94
+
95
+ n_steps = int(total_time / time_step)
96
+
97
+ t_schedule, p_schedule = _get_ensemble_schedule(
98
+ ensemble=ensemble,
99
+ n_steps=n_steps,
100
+ temperature=temperature,
101
+ pressure=pressure,
102
+ )
103
+
104
+ ase_md_kwargs = _get_ensemble_defaults(
105
+ ensemble=ensemble,
106
+ dynamics=dynamics,
107
+ t_schedule=t_schedule,
108
+ p_schedule=p_schedule,
109
+ ase_md_kwargs=ase_md_kwargs,
110
+ )
111
+
112
+ if isinstance(dynamics, str):
113
+ # Use known dynamics if `self.dynamics` is a str
114
+ dynamics = dynamics.lower()
115
+ if dynamics not in _valid_dynamics[ensemble]:
116
+ raise ValueError(
117
+ f"{dynamics} thermostat not available for {ensemble}."
118
+ f"Available {ensemble} thermostats are:"
119
+ " ".join(_valid_dynamics[ensemble])
120
+ )
121
+
122
+ if ensemble == "nve" and dynamics is None:
123
+ dynamics = "velocityverlet"
124
+ md_class = _preset_dynamics[f"{ensemble}_{dynamics}"]
125
+ elif issubclass(dynamics, MolecularDynamics):
126
+ md_class = dynamics
127
+
128
+ if md_class is NPT:
129
+ # Note that until md_func is instantiated, isinstance(md_func,NPT) is False
130
+ # ASE NPT implementation requires upper triangular cell
131
+ u, _ = schur(atoms.get_cell(complete=True), output="complex")
132
+ atoms.set_cell(u.real, scale_atoms=True)
133
+
134
+ last_step = 0
135
+
136
+ if traj_file is not None:
137
+ traj_file = Path(traj_file)
138
+
139
+ if restart and traj_file.exists():
140
+ traj = read(traj_file, index=":")
141
+ last_step = len(traj)
142
+ n_steps -= len(traj)
143
+ last_atoms = traj[-1]
144
+ traj = Trajectory(traj_file, "a", atoms)
145
+ atoms.set_positions(last_atoms.get_positions())
146
+ atoms.set_momenta(last_atoms.get_momenta())
147
+ else:
148
+ traj = Trajectory(traj_file, "w", atoms)
149
+
150
+ if not np.isnan(t_schedule).any():
151
+ MaxwellBoltzmannDistribution(
152
+ atoms=atoms,
153
+ temperature_K=t_schedule[last_step],
154
+ rng=np.random.default_rng(seed=mb_velocity_seed),
155
+ )
156
+
157
+ if zero_linear_momentum:
158
+ Stationary(atoms)
159
+ if zero_angular_momentum:
160
+ ZeroRotation(atoms)
161
+
162
+ md_runner = md_class(
163
+ atoms=atoms,
164
+ timestep=time_step * units.fs,
165
+ **ase_md_kwargs,
166
+ )
167
+
168
+ if traj_file is not None:
169
+ md_runner.attach(traj.write, interval=traj_interval)
170
+
171
+ with tqdm(total=n_steps) as pbar:
172
+
173
+ def _callback(dyn: MolecularDynamics = md_runner) -> None:
174
+ if ensemble == "nve":
175
+ return
176
+ dyn.set_temperature(temperature_K=t_schedule[last_step + dyn.nsteps])
177
+ if ensemble == "nvt":
178
+ return
179
+ dyn.set_stress(p_schedule[last_step + dyn.nsteps] * 1e3 * units.bar)
180
+ pbar.update()
181
+
182
+ md_runner.attach(_callback, interval=1)
183
+
184
+ start_time = datetime.now()
185
+ md_runner.run(steps=n_steps)
186
+ end_time = datetime.now()
187
+
188
+ traj.close()
189
+
190
+ return {"md_runtime": end_time - start_time}