Spaces:
Running
Running
File size: 3,788 Bytes
1effaf5 52c1bfb 1effaf5 c7922c2 1effaf5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""
Define structure optimization tasks.
"""
from __future__ import annotations
from datetime import timedelta
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
from ase.filters import * # type: ignore
from ase.filters import Filter
from ase.optimize import * # type: ignore
from ase.optimize.optimize import Optimizer
from prefect import task
from prefect.tasks import task_input_hash
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
from mlip_arena.models import MLIPEnum
from mlip_arena.models.utils import get_freer_device
# Registry of filter names accepted by ``run(filter=...)``.
# The values are filter *classes* (not instances): ``run`` instantiates the
# selected class around the ``Atoms`` object, hence ``type[Filter]``.
_valid_filters: dict[str, type[Filter]] = {
    "Filter": Filter,
    "UnitCell": UnitCellFilter,
    "ExpCell": ExpCellFilter,
    "Strain": StrainFilter,
    "FrechetCell": FrechetCellFilter,
}  # type: ignore
# Registry of optimizer names accepted by ``run(optimizer=...)``.
# The values are optimizer *classes* (not instances): ``run`` instantiates the
# selected class itself, hence ``type[Optimizer]``.
_valid_optimizers: dict[str, type[Optimizer]] = {
    "MDMin": MDMin,
    "FIRE": FIRE,
    "LBFGS": LBFGS,
    "LBFGSLineSearch": LBFGSLineSearch,
    "BFGS": BFGS,
    "BFGSLineSearch": BFGSLineSearch,
    "QuasiNewton": QuasiNewton,
    "GPMin": GPMin,
    "CellAwareBFGS": CellAwareBFGS,
    "ODE12r": ODE12r,
}  # type: ignore
# @task(
# cache_key_fn=task_input_hash,
# cache_expiration=timedelta(days=1),
# timeout_seconds=120)
@task(timeout_seconds=120, result_storage=None)
def run(
    atoms: Atoms,
    calculator_name: str | MLIPEnum,
    calculator_kwargs: dict | None,
    dispersion: str | None = None,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
    optimizer: type[Optimizer] | str = BFGSLineSearch,
    optimizer_kwargs: dict | None = None,
    filter: type[Filter] | str | None = None,
    filter_kwargs: dict | None = None,
    criterion: dict | None = None,
    # TODO: fix symmetry
) -> dict[str, Atoms]:
    """Optimize a structure with an MLIP calculator.

    The calculator is attached to ``atoms`` and the optimizer operates on that
    same object, so the caller's ``atoms`` is modified and is also the object
    returned in the result.

    Args:
        atoms: Structure to optimize.
        calculator_name: An ``MLIPEnum`` member, or the name of one.
        calculator_kwargs: Keyword arguments for the calculator constructor
            (``None`` means no arguments).
        dispersion: When not ``None``, a ``TorchDFTD3Calculator`` is summed
            with the main calculator. The value itself is only echoed in the
            log line; the correction is configured via ``dispersion_kwargs``.
        dispersion_kwargs: Keyword arguments for ``TorchDFTD3Calculator``.
            ``"device"`` is always overridden with the resolved device. The
            caller's dict is not modified.
        device: Device string; defaults to ``str(get_freer_device())``.
        optimizer: Optimizer class, or a key of ``_valid_optimizers``.
        optimizer_kwargs: Extra keyword arguments for the optimizer
            constructor.
        filter: ``Filter`` subclass, a key of ``_valid_filters``, or ``None``
            to optimize the atoms directly.
        filter_kwargs: Extra keyword arguments for the filter constructor.
        criterion: Keyword arguments forwarded to the optimizer's ``run``
            method (e.g. ``fmax``/``steps`` for ASE optimizers).

    Returns:
        ``{"atoms": atoms}`` with the optimized structure.

    Raises:
        ValueError: If the calculator, optimizer name or filter is not
            recognized.
    """
    # Resolve and validate the string options first so that a typo fails
    # before a (potentially expensive) calculator is constructed.
    if isinstance(filter, str):
        if filter not in _valid_filters:
            raise ValueError(f"Invalid filter: {filter}")
        filter = _valid_filters[filter]
    if filter is not None and not (
        isinstance(filter, type) and issubclass(filter, Filter)
    ):
        # Previously such values were ignored and the atoms were returned
        # without any optimization being run.
        raise ValueError(f"Invalid filter: {filter}")

    if isinstance(optimizer, str):
        if optimizer not in _valid_optimizers:
            raise ValueError(f"Invalid optimizer: {optimizer}")
        optimizer = _valid_optimizers[optimizer]

    device = device or str(get_freer_device())
    print(f"Using device: {device}")

    # NOTE(review): ``device`` is only forwarded to the dispersion calculator;
    # the MLIP calculator receives ``calculator_kwargs`` alone.
    calculator_kwargs = calculator_kwargs or {}
    if isinstance(calculator_name, MLIPEnum):
        assert issubclass(calculator_name.value, Calculator)
        calc = calculator_name.value(**calculator_kwargs)
    elif (
        isinstance(calculator_name, str)
        and calculator_name in MLIPEnum._member_names_
    ):
        calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
    else:
        raise ValueError(f"Invalid calculator: {calculator_name}")
    print(f"Using calculator: {calc}")

    if dispersion is not None:
        # Build a fresh dict instead of updating the caller's in place.
        disp_calc = TorchDFTD3Calculator(
            **{**(dispersion_kwargs or {}), "device": device},
        )
        calc = SumCalculator([calc, disp_calc])
        print(f"Using dispersion: {dispersion}")

    atoms.calc = calc

    filter_kwargs = filter_kwargs or {}
    optimizer_kwargs = optimizer_kwargs or {}
    criterion = criterion or {}

    if filter is None:
        target = atoms
    else:
        target = filter(atoms, **filter_kwargs)
        print(f"Using filter: {target}")

    # The optimizer must be built on the filter (when one is given), not on
    # the bare atoms; otherwise the filter has no effect on the relaxation.
    optimizer_instance = optimizer(target, **optimizer_kwargs)
    print(f"Using optimizer: {optimizer_instance}")
    optimizer_instance.run(**criterion)

    return {
        "atoms": atoms,
    }
|