{ "cells": [
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No module named 'deepmd'\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"from ase import Atoms\n",
"from ase.io import read, write\n",
"from dask.distributed import Client\n",
"from dask_jobqueue import SLURMCluster\n",
"from dotenv import load_dotenv\n",
"from prefect import flow, task\n",
"from prefect.futures import wait\n",
"from prefect.task_runners import ThreadPoolTaskRunner\n",
"from prefect_dask import DaskTaskRunner\n",
"from pymatgen.core.structure import Structure\n",
"\n",
"from mlip_arena.models import MLIPEnum\n",
"from mlip_arena.tasks.eos.run import fit as EOS\n",
"\n",
"# Load variables from a local .env file into the environment.\n",
"load_dotenv()"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MP Database version: 2023.11.1\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bb6c1969c89840888c556f8fa59b4a67", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Retrieving SummaryDoc documents: 0%| | 0/5135 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Local cache of the structures; the next cell always reads from this file.\n",
"STRUCTURES_FILE = \"all.extxyz\"\n",
"\n",
"# Only query the Materials Project when the cache is missing, so a fresh\n",
"# Restart & Run All works without mp_api or an API key once the file exists.\n",
"if not Path(STRUCTURES_FILE).exists():\n",
"    # Imported lazily: mp_api is only needed to rebuild the cache.\n",
"    from mp_api.client import MPRester\n",
"\n",
"    with MPRester(os.environ.get(\"MP_API_KEY\", None)) as mpr:\n",
"        print(\"MP Database version:\", mpr.get_database_version())\n",
"\n",
"        summary_docs = mpr.materials.summary.search(\n",
"            num_elements=(1, 2),\n",
"            is_stable=True,\n",
"            fields=[\"material_id\", \"structure\", \"formula_pretty\"]\n",
"        )\n",
"\n",
"    # Convert every pymatgen Structure to an ASE Atoms object.\n",
"    atoms_list = []\n",
"    for doc in summary_docs:\n",
"        structure = doc.structure\n",
"        assert isinstance(structure, Structure)\n",
"        atoms_list.append(structure.to_ase_atoms())\n",
"\n",
"    write(STRUCTURES_FILE, atoms_list)"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [
"atoms_list = read(STRUCTURES_FILE, index=':')"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "#!/bin/bash\n", "\n", "#SBATCH -A matgen\n", "#SBATCH --mem=0\n", "#SBATCH -t 00:30:00\n", "#SBATCH -N 1\n", "#SBATCH -G 4\n", "#SBATCH -q debug\n", "#SBATCH -C gpu\n", "#SBATCH -J eos\n", "source ~/.bashrc\n", "module load python\n", "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\n", "/pscratch/sd/c/cyrusyc/.conda/mlip-arena/bin/python -m distributed.cli.dask_worker tcp://128.55.64.49:36289 --name dummy-name --nthreads 1 --memory-limit 59.60GiB --nanny --death-timeout 60\n", "\n" ] } ], "source": [
"nodes_per_alloc = 1\n",
"gpus_per_alloc = 4\n",
"ntasks = 1\n",
"\n",
"cluster_kwargs = {\n",
"    \"cores\": 1,\n",
"    \"memory\": \"64 GB\",\n",
"    \"shebang\": \"#!/bin/bash\",\n",
"    \"account\": \"matgen\",\n",
"    \"walltime\": \"00:30:00\",\n",
"    \"job_mem\": \"0\",\n",
"    # Shell lines placed in the job script before the worker command.\n",
"    \"job_script_prologue\": [\n",
"        \"source ~/.bashrc\",\n",
"        \"module load python\",\n",
"        \"source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena\",\n",
"    ],\n",
"    # The generated -n, --cpus-per-task and -J directives are skipped;\n",
"    # the job name is supplied again below as \"-J eos\".\n",
"    \"job_directives_skip\": [\"-n\", \"--cpus-per-task\", \"-J\"],\n",
"    \"job_extra_directives\": [f\"-N {nodes_per_alloc}\", f\"-G {gpus_per_alloc}\", \"-q debug\", \"-C gpu\", \"-J eos\"],\n",
"}\n",
"cluster = SLURMCluster(**cluster_kwargs)\n",
"\n",
"# Show the generated batch script for inspection.\n",
"print(cluster.job_script())\n",
"cluster.adapt(minimum_jobs=2, maximum_jobs=2)\n",
"client = Client(cluster)\n"
] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "from prefect.concurrency.sync import concurrency\n", "from 
prefect.runtime import flow_run, task_run\n",
"from prefect import get_client\n",
"\n",
"# `postprocess` looks up each model's output directory in REGISTRY.\n",
"from mlip_arena.models import REGISTRY\n",
"\n",
"\n",
"def postprocess(output, model: str, formula: str):\n",
"    \"\"\"Insert or replace one EOS result in the model's parquet table.\n",
"\n",
"    Args:\n",
"        output: EOS result; read as output[\"eos\"][\"volumes\"],\n",
"            output[\"eos\"][\"energies\"] and output[\"K\"].\n",
"        model: Model name; used as the REGISTRY key, the file stem and the\n",
"            \"method\" column.\n",
"        formula: Chemical formula stored in the \"formula\" column.\n",
"\n",
"    The table is stored at <REGISTRY[model][\"family\"]>/<model>.parquet. If a\n",
"    row with the same (formula, method) already exists, the new row wins.\n",
"\n",
"    NOTE(review): this is an unlocked read-modify-write of a shared file;\n",
"    tasks finishing at the same time for one model may overwrite each\n",
"    other's rows - confirm whether that matters for this sweep.\n",
"    \"\"\"\n",
"    row = {\n",
"        \"formula\": formula,\n",
"        \"method\": model,\n",
"        \"volumes\": output[\"eos\"][\"volumes\"],\n",
"        \"energies\": output[\"eos\"][\"energies\"],\n",
"        \"K\": output[\"K\"],\n",
"    }\n",
"    new_row = pd.DataFrame([row])\n",
"\n",
"    fpath = Path(REGISTRY[model][\"family\"]) / f\"{model}.parquet\"\n",
"\n",
"    if fpath.exists():\n",
"        df = pd.concat([pd.read_parquet(fpath), new_row], ignore_index=True)\n",
"    else:\n",
"        fpath.parent.mkdir(parents=True, exist_ok=True)\n",
"        df = new_row\n",
"\n",
"    df = df.drop_duplicates(subset=[\"formula\", \"method\"], keep=\"last\")\n",
"    df.to_parquet(fpath)\n",
"\n",
"\n",
"task_runner = DaskTaskRunner(address=client.scheduler.address)\n",
"EOS = EOS.with_options(\n",
"    log_prints=True,\n",
"    timeout_seconds=120,\n",
")\n",
"\n",
"# Register the tag-based concurrency limit for tasks tagged \"bottleneck\".\n",
"# The Prefect client gets its own name so the Dask `client` created in the\n",
"# cluster cell is not overwritten (which would break re-running this cell).\n",
"async with get_client() as prefect_client:\n",
"    limit_id = await prefect_client.create_concurrency_limit(\n",
"        tag=\"bottleneck\",\n",
"        concurrency_limit=2\n",
"    )\n",
"\n",
"\n",
"def generate_task_run_name():\n",
"    \"\"\"Build the run name \"<task name>: <chemical formula>\" from the `atoms` parameter.\"\"\"\n",
"    atoms = task_run.parameters[\"atoms\"]\n",
"    return f\"{task_run.task_name}: {atoms.get_chemical_formula()}\"\n",
"\n",
"\n",
"@task(task_run_name=generate_task_run_name, tags=[\"bottleneck\"], timeout_seconds=150)\n",
"def fit_one(atoms: Atoms, model: str):\n",
"    \"\"\"Run the EOS task for one structure with one model and store the result.\n",
"\n",
"    Args:\n",
"        atoms: Structure to fit.\n",
"        model: Calculator name passed to EOS as `calculator_name`.\n",
"\n",
"    Returns:\n",
"        Whatever EOS returned. A dict result is first written to the model's\n",
"        parquet table via `postprocess` and gets a \"method\" key added.\n",
"    \"\"\"\n",
"    eos = EOS(\n",
"        atoms=atoms,\n",
"        calculator_name=model,\n",
"        calculator_kwargs={},\n",
"        device=None,\n",
"        optimizer=\"QuasiNewton\",\n",
"        optimizer_kwargs=None,\n",
"        filter=\"FrechetCell\",\n",
"        filter_kwargs=None,\n",
"        criterion=dict(\n",
"            fmax=0.1,\n",
"        ),\n",
"        max_abs_strain=0.1,\n",
"        npoints=7,\n",
"    )\n",
"    if isinstance(eos, dict):\n",
"        postprocess(output=eos, model=model, formula=atoms.get_chemical_formula())\n",
"        eos[\"method\"] = model\n",
"\n",
"    return eos\n",
"\n",
"\n",
"# https://docs-3.prefect.io/3.0/develop/task-runners#use-multiple-task-runners\n",
"@flow(task_runner=task_runner, log_prints=True)\n",
"def fit_all(atoms_list: list[Atoms]):\n",
"    \"\"\"Submit `fit_one` for every (structure, model) pair and collect the outcomes.\n",
"\n",
"    Args:\n",
"        atoms_list: Structures to fit; each one is paired with every member\n",
"            of MLIPEnum.\n",
"\n",
"    Returns:\n",
"        One entry per submitted task, in submission order (structures outer,\n",
"        models inner). Results are fetched with raise_on_failure=False so a\n",
"        single failed or timed-out task does not abort the rest of the sweep.\n",
"    \"\"\"\n",
"    futures = [\n",
"        fit_one.submit(atoms, model.name)\n",
"        for atoms in atoms_list\n",
"        for model in MLIPEnum\n",
"    ]\n",
"\n",
"    # Block until every task has finished before reading any result.\n",
"    wait(futures)\n",
"\n",
"    return [f.result(raise_on_failure=False) for f in futures]"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
18:53:47.335 | INFO | prefect.engine - Created flow run 'vengeful-malkoha' for flow 'fit-all'\n", "\n" ], "text/plain": [ "18:53:47.335 | \u001b[36mINFO\u001b[0m | prefect.engine - Created flow run\u001b[35m 'vengeful-malkoha'\u001b[0m for flow\u001b[1;35m 'fit-all'\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
18:53:47.341 | INFO | prefect.engine - View at https://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/909d2bc4-695f-4eeb-8b7c-7660397a0692\n", "\n" ], "text/plain": [ "18:53:47.341 | \u001b[36mINFO\u001b[0m | prefect.engine - View at \u001b[94mhttps://app.prefect.cloud/account/f7d40474-9362-4bfa-8950-ee6a43ec00f3/workspace/d4bb0913-5f5e-49f7-bfc5-06509088baeb/runs/flow-run/909d2bc4-695f-4eeb-8b7c-7660397a0692\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
18:53:47.654 | INFO | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(df8c3d55, 'tcp://128.55.64.49:36289', workers=0, threads=0, memory=0 B)\n",
"
\n"
],
"text/plain": [
"18:53:47.654 | \u001b[36mINFO\u001b[0m | prefect.task_runner.dask - Connecting to existing Dask cluster SLURMCluster(df8c3d55, 'tcp://128.55.64.49:36289', workers=0, threads=0, memory=0 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Bind the return value so the cell does not echo every result as its output.\n",
"eos_results = fit_all(atoms_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note:** Because the `DaskTaskRunner` uses multiprocessing, calls to flows in scripts must be guarded with `if __name__ == \"__main__\":`, or you will encounter warnings and errors."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# import os\n",
"# import tempfile\n",
"# import shutil\n",
"# from contextlib import contextmanager\n",
"\n",
"# @contextmanager\n",
"# def twd():\n",
" \n",
"# pwd = os.getcwd()\n",
"# temp_dir = tempfile.mkdtemp()\n",
" \n",
"# try:\n",
"# os.chdir(temp_dir)\n",
"# yield\n",
"# finally:\n",
"# os.chdir(pwd)\n",
"# shutil.rmtree(temp_dir)\n",
"\n",
"# with twd():\n",
"\n",
"# fit_all(atoms_list)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Table of EOS results stored for the MACE-MP(M) model.\n",
"df = pd.read_parquet(\"mace-mp/MACE-MP(M).parquet\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n", " | formula | \n", "method | \n", "volumes | \n", "energies | \n", "K | \n", "
---|---|---|---|---|---|
1 | \n", "Ac2O3 | \n", "MACE-MP(M) | \n", "[82.36010147441682, 85.41047560309894, 88.4608... | \n", "[-39.47541427612305, -39.65580749511719, -39.7... | \n", "95.755459 | \n", "
2 | \n", "Ac6In2 | \n", "MACE-MP(M) | \n", "[278.3036976131417, 288.61124196918433, 298.91... | \n", "[-31.21324348449707, -31.40914535522461, -31.5... | \n", "33.370214 | \n", "
3 | \n", "Ac6Tl2 | \n", "MACE-MP(M) | \n", "[278.30267000598286, 288.6101763025008, 298.91... | \n", "[-29.572534561157227, -29.833026885986328, -30... | \n", "29.065081 | \n", "
4 | \n", "Ac3Sn | \n", "MACE-MP(M) | \n", "[135.293532345587, 140.30440391394214, 145.315... | \n", "[-17.135194778442383, -17.228239059448242, -17... | \n", "30.622045 | \n", "
5 | \n", "AcAg | \n", "MACE-MP(M) | \n", "[55.376437498321394, 57.4274166649259, 59.4783... | \n", "[-7.274301528930664, -7.346108913421631, -7.39... | \n", "40.212164 | \n", "
6 | \n", "Ac4 | \n", "MACE-MP(M) | \n", "[166.09086069175856, 172.2423740507126, 178.39... | \n", "[-16.326059341430664, -16.406923294067383, -16... | \n", "25.409891 | \n", "
7 | \n", "Ac16S24 | \n", "MACE-MP(M) | \n", "[1006.5670668063424, 1043.84732853991, 1081.12... | \n", "[-249.4179229736328, -250.7970733642578, -251.... | \n", "61.734158 | \n", "