In [3]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
from ase import Atom, Atoms
from ase.data import chemical_symbols, covalent_radii, vdw_alvarez
from ase.io import read, write
from pymatgen.core import Element
from scipy import stats
from tqdm.auto import tqdm

from mlip_arena.models import MLIPEnum, REGISTRY

%matplotlib inline

In [7]:
# Models to (re)compute. Set to None to sweep every model in MLIPEnum.
MODELS_TO_RUN = {"eqV2(OMat)"}
# Target spacing (Å) of the interatomic-distance grid.
R_STEP = 0.01


def _dimer_positions(a, r):
    """Return positions of two atoms ``r`` apart along x, centred in a box of edge ``a``."""
    return [
        [a/2 - r/2, a/2, a/2],
        [a/2 + r/2, a/2, a/2],
    ]


for model in MLIPEnum:

    model_name = model.name

    if MODELS_TO_RUN is not None and model_name not in MODELS_TO_RUN:
        continue

    print(f"========== {model_name} ==========")

    calc = MLIPEnum[model_name].value()

    # chemical_symbols[0] is the dummy "X", so from index 1 on the index is the atomic number.
    for z, symbol in enumerate(tqdm(chemical_symbols[1:]), start=1):
        try:
            # Scan from 90% of the covalent radius out to 3.1x the Alvarez vdW radius
            # (falling back to 6 Å when no vdW radius is tabulated for this element).
            rmin = 0.9 * covalent_radii[z]
            rvdw = vdw_alvarez.vdw_radii[z] if z < len(vdw_alvarez.vdw_radii) else np.nan
            rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6
            npts = int((rmax - rmin) / R_STEP)
            rs = np.linspace(rmin, rmax, npts)

            # Box edge is twice the largest separation; the slightly unequal edges
            # presumably avoid an exactly cubic cell.
            a = 2 * rmax

            da = symbol + symbol
            out_dir = Path(REGISTRY[model_name]["family"]) / da
            out_dir.mkdir(parents=True, exist_ok=True)
            traj_fpath = out_dir / f"{model_name}.extxyz"

            # Resume support: frames already on disk are assumed to be the first
            # len(traj) points of `rs` (only valid if rmin/rmax/R_STEP are unchanged).
            if traj_fpath.exists():
                traj = read(traj_fpath, index=":")
                skip = len(traj)
                atoms = traj[-1]
            else:
                skip = 0
                atoms = Atoms(
                    da,
                    positions=_dimer_positions(a, rs[0]),
                    cell=[a, a+0.001, a+0.002],
                    pbc=True,
                )

            print(atoms)

            atoms.calc = calc

            for r in tqdm(rs[skip:]):
                atoms.set_positions(_dimer_positions(a, r))
                # Evaluate the model so the frame written below carries its results.
                atoms.get_potential_energy()
                write(traj_fpath, atoms, append=True)
        except Exception as e:
            # Best effort: report and move on (e.g. elements beyond the model's
            # supported atomic numbers).
            print(f"{symbol}: {e}")






  0%|          | 0/118 [00:00<?, ?it/s]

Atoms(symbols='H2', pbc=True, cell=[7.4399999999999995, 7.441, 7.441999999999999], calculator=SinglePointCalculator(...))


  0%|          | 0/344 [00:00<?, ?it/s]

Atoms(symbols='He2', pbc=True, cell=[8.866, 8.866999999999999, 8.868], calculator=SinglePointCalculator(...))


  0%|          | 0/418 [00:00<?, ?it/s]

Atoms(symbols='Li2', pbc=True, cell=[13.144000000000002, 13.145000000000001, 13.146000000000003])


  0%|          | 0/542 [00:00<?, ?it/s]

Atoms(symbols='Be2', pbc=True, cell=[12.276, 12.277, 12.278])


  0%|          | 0/527 [00:00<?, ?it/s]

Atoms(symbols='B2', pbc=True, cell=[11.842, 11.843, 11.844000000000001])


  0%|          | 0/516 [00:00<?, ?it/s]

Atoms(symbols='C2', pbc=True, cell=[10.974, 10.975, 10.976])


  0%|          | 0/480 [00:00<?, ?it/s]

Atoms(symbols='N2', pbc=True, cell=[10.292, 10.293, 10.294])


  0%|          | 0/450 [00:00<?, ?it/s]

Atoms(symbols='O2', pbc=True, cell=[9.3, 9.301, 9.302000000000001])


  0%|          | 0/405 [00:00<?, ?it/s]

Atoms(symbols='F2', pbc=True, cell=[9.052, 9.052999999999999, 9.054])


  0%|          | 0/401 [00:00<?, ?it/s]

Atoms(symbols='Ne2', pbc=True, cell=[9.796000000000001, 9.797, 9.798000000000002])


  0%|          | 0/437 [00:00<?, ?it/s]

Atoms(symbols='Na2', pbc=True, cell=[15.5, 15.501, 15.502])


  0%|          | 0/625 [00:00<?, ?it/s]

Atoms(symbols='Mg2', pbc=True, cell=[15.562, 15.562999999999999, 15.564])


  0%|          | 0/651 [00:00<?, ?it/s]

Atoms(symbols='Al2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/588 [00:00<?, ?it/s]

Atoms(symbols='Si2', pbc=True, cell=[13.578, 13.578999999999999, 13.58])


  0%|          | 0/578 [00:00<?, ?it/s]

Atoms(symbols='P2', pbc=True, cell=[11.78, 11.780999999999999, 11.782])


  0%|          | 0/492 [00:00<?, ?it/s]

Atoms(symbols='S2', pbc=True, cell=[11.718, 11.719, 11.72])


  0%|          | 0/491 [00:00<?, ?it/s]

Atoms(symbols='Cl2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/472 [00:00<?, ?it/s]

Atoms(symbols='Ar2', pbc=True, cell=[11.346, 11.347, 11.348])


  0%|          | 0/471 [00:00<?, ?it/s]

Atoms(symbols='K2', pbc=True, cell=[16.926000000000002, 16.927000000000003, 16.928])


  0%|          | 0/663 [00:00<?, ?it/s]

Atoms(symbols='Ca2', pbc=True, cell=[16.244, 16.245, 16.246])


  0%|          | 0/653 [00:00<?, ?it/s]

Atoms(symbols='Sc2', pbc=True, cell=[15.996, 15.997, 15.998000000000001])


  0%|          | 0/646 [00:00<?, ?it/s]

Atoms(symbols='Ti2', pbc=True, cell=[15.252, 15.253, 15.254000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

Atoms(symbols='V2', pbc=True, cell=[15.004, 15.004999999999999, 15.006])


  0%|          | 0/612 [00:00<?, ?it/s]

Atoms(symbols='Cr2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

Atoms(symbols='Mn2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

Atoms(symbols='Fe2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/637 [00:00<?, ?it/s]

Atoms(symbols='Co2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/630 [00:00<?, ?it/s]

Atoms(symbols='Ni2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/632 [00:00<?, ?it/s]

Atoms(symbols='Cu2', pbc=True, cell=[14.756, 14.757, 14.758000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

Atoms(symbols='Zn2', pbc=True, cell=[14.818000000000001, 14.819, 14.820000000000002])


  0%|          | 0/631 [00:00<?, ?it/s]

Atoms(symbols='Ga2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386])


  0%|          | 0/609 [00:00<?, ?it/s]

Atoms(symbols='Ge2', pbc=True, cell=[14.198, 14.199, 14.200000000000001])


  0%|          | 0/601 [00:00<?, ?it/s]

Atoms(symbols='As2', pbc=True, cell=[11.655999999999999, 11.656999999999998, 11.658])


  0%|          | 0/475 [00:00<?, ?it/s]

Atoms(symbols='Se2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/456 [00:00<?, ?it/s]

Atoms(symbols='Br2', pbc=True, cell=[11.532000000000002, 11.533000000000001, 11.534000000000002])


  0%|          | 0/468 [00:00<?, ?it/s]

Atoms(symbols='Kr2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/593 [00:00<?, ?it/s]

Atoms(symbols='Rb2', pbc=True, cell=[19.902, 19.903000000000002, 19.904])


  0%|          | 0/797 [00:00<?, ?it/s]

Atoms(symbols='Sr2', pbc=True, cell=[17.608, 17.609, 17.61])


  0%|          | 0/704 [00:00<?, ?it/s]

Atoms(symbols='Y2', pbc=True, cell=[17.05, 17.051000000000002, 17.052])


  0%|          | 0/681 [00:00<?, ?it/s]

Atoms(symbols='Zr2', pbc=True, cell=[15.624, 15.625, 15.626000000000001])


  0%|          | 0/623 [00:00<?, ?it/s]

Atoms(symbols='Nb2', pbc=True, cell=[15.872000000000002, 15.873000000000001, 15.874000000000002])


  0%|          | 0/646 [00:00<?, ?it/s]

Atoms(symbols='Mo2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/620 [00:00<?, ?it/s]

Atoms(symbols='Tc2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/624 [00:00<?, ?it/s]

Atoms(symbols='Ru2', pbc=True, cell=[15.252, 15.253, 15.254000000000001])


  0%|          | 0/631 [00:00<?, ?it/s]

Atoms(symbols='Rh2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/628 [00:00<?, ?it/s]

Atoms(symbols='Pd2', pbc=True, cell=[13.33, 13.331, 13.332])


  0%|          | 0/541 [00:00<?, ?it/s]

Atoms(symbols='Ag2', pbc=True, cell=[15.686, 15.687, 15.688])


  0%|          | 0/653 [00:00<?, ?it/s]

Atoms(symbols='Cd2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003])


  0%|          | 0/642 [00:00<?, ?it/s]

Atoms(symbols='In2', pbc=True, cell=[15.066, 15.067, 15.068000000000001])


  0%|          | 0/625 [00:00<?, ?it/s]

Atoms(symbols='Sn2', pbc=True, cell=[15.004, 15.004999999999999, 15.006])


  0%|          | 0/625 [00:00<?, ?it/s]

Atoms(symbols='Sb2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003])


  0%|          | 0/640 [00:00<?, ?it/s]

Atoms(symbols='Te2', pbc=True, cell=[12.338000000000001, 12.339, 12.340000000000002])


  0%|          | 0/492 [00:00<?, ?it/s]

Atoms(symbols='I2', pbc=True, cell=[12.648000000000001, 12.649000000000001, 12.650000000000002])


  0%|          | 0/507 [00:00<?, ?it/s]

Atoms(symbols='Xe2', pbc=True, cell=[12.772, 12.773, 12.774000000000001])


  0%|          | 0/512 [00:00<?, ?it/s]

Atoms(symbols='Cs2', pbc=True, cell=[21.576, 21.577, 21.578])


  0%|          | 0/859 [00:00<?, ?it/s]

Atoms(symbols='Ba2', pbc=True, cell=[18.785999999999998, 18.787, 18.787999999999997])


  0%|          | 0/745 [00:00<?, ?it/s]

Atoms(symbols='La2', pbc=True, cell=[18.476, 18.477, 18.477999999999998])


  0%|          | 0/737 [00:00<?, ?it/s]

Atoms(symbols='Ce2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997])


  0%|          | 0/709 [00:00<?, ?it/s]

Atoms(symbols='Pr2', pbc=True, cell=[18.104, 18.105, 18.105999999999998])


  0%|          | 0/722 [00:00<?, ?it/s]

Atoms(symbols='Nd2', pbc=True, cell=[18.290000000000003, 18.291000000000004, 18.292])


  0%|          | 0/733 [00:00<?, ?it/s]

Atoms(symbols='Pm2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/420 [00:00<?, ?it/s]

Atoms(symbols='Sm2', pbc=True, cell=[17.98, 17.981, 17.982])


  0%|          | 0/720 [00:00<?, ?it/s]

Atoms(symbols='Eu2', pbc=True, cell=[17.794, 17.795, 17.796])


  0%|          | 0/711 [00:00<?, ?it/s]

Atoms(symbols='Gd2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/700 [00:00<?, ?it/s]

Atoms(symbols='Tb2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3])


  0%|          | 0/690 [00:00<?, ?it/s]

Atoms(symbols='Dy2', pbc=True, cell=[17.794, 17.795, 17.796])


  0%|          | 0/716 [00:00<?, ?it/s]

Atoms(symbols='Ho2', pbc=True, cell=[17.422, 17.423000000000002, 17.424])


  0%|          | 0/698 [00:00<?, ?it/s]

Atoms(symbols='Er2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/707 [00:00<?, ?it/s]

Atoms(symbols='Tm2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3])


  0%|          | 0/693 [00:00<?, ?it/s]

Atoms(symbols='Yb2', pbc=True, cell=[17.36, 17.361, 17.362])


  0%|          | 0/699 [00:00<?, ?it/s]

Atoms(symbols='Lu2', pbc=True, cell=[16.988000000000003, 16.989000000000004, 16.990000000000002])


  0%|          | 0/681 [00:00<?, ?it/s]

Atoms(symbols='Hf2', pbc=True, cell=[16.306, 16.307000000000002, 16.308])


  0%|          | 0/657 [00:00<?, ?it/s]

Atoms(symbols='Ta2', pbc=True, cell=[15.686, 15.687, 15.688])


  0%|          | 0/631 [00:00<?, ?it/s]

Atoms(symbols='W2', pbc=True, cell=[15.934, 15.934999999999999, 15.936])


  0%|          | 0/650 [00:00<?, ?it/s]

Atoms(symbols='Re2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003])


  0%|          | 0/636 [00:00<?, ?it/s]

Atoms(symbols='Os2', pbc=True, cell=[15.376, 15.376999999999999, 15.378])


  0%|          | 0/639 [00:00<?, ?it/s]

Atoms(symbols='Ir2', pbc=True, cell=[14.942000000000002, 14.943000000000001, 14.944000000000003])


  0%|          | 0/620 [00:00<?, ?it/s]

Atoms(symbols='Pt2', pbc=True, cell=[14.198, 14.199, 14.200000000000001])


  0%|          | 0/587 [00:00<?, ?it/s]

Atoms(symbols='Au2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386])


  0%|          | 0/596 [00:00<?, ?it/s]

Atoms(symbols='Hg2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/640 [00:00<?, ?it/s]

Atoms(symbols='Tl2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003])


  0%|          | 0/635 [00:00<?, ?it/s]

Atoms(symbols='Pb2', pbc=True, cell=[16.12, 16.121000000000002, 16.122])


  0%|          | 0/674 [00:00<?, ?it/s]

Atoms(symbols='Bi2', pbc=True, cell=[15.748000000000001, 15.749, 15.750000000000002])


  0%|          | 0/654 [00:00<?, ?it/s]

Atoms(symbols='Po2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/474 [00:00<?, ?it/s]

Atoms(symbols='At2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/465 [00:00<?, ?it/s]

Atoms(symbols='Rn2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/465 [00:00<?, ?it/s]

Atoms(symbols='Fr2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/365 [00:00<?, ?it/s]

Atoms(symbols='Ra2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/401 [00:00<?, ?it/s]

Atoms(symbols='Ac2', pbc=True, cell=[17.36, 17.361, 17.362])


  0%|          | 0/674 [00:00<?, ?it/s]

Atoms(symbols='Th2', pbc=True, cell=[18.166, 18.167, 18.168])


  0%|          | 0/722 [00:00<?, ?it/s]

Atoms(symbols='Pa2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997])


  0%|          | 0/712 [00:00<?, ?it/s]

Atoms(symbols='U2', pbc=True, cell=[16.802, 16.803, 16.804])


  0%|          | 0/663 [00:00<?, ?it/s]

Atoms(symbols='Np2', pbc=True, cell=[17.483999999999998, 17.485, 17.485999999999997])


  0%|          | 0/703 [00:00<?, ?it/s]

Atoms(symbols='Pu2', pbc=True, cell=[17.422, 17.423000000000002, 17.424])


  0%|          | 0/702 [00:00<?, ?it/s]

Atoms(symbols='Am2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/715 [00:00<?, ?it/s]

Atoms(symbols='Cm2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/793 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Bk2', pbc=True, cell=[21.08, 21.081, 21.081999999999997])


  0%|          | 0/1036 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Cf2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/927 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Es2', pbc=True, cell=[16.740000000000002, 16.741000000000003, 16.742])


  0%|          | 0/819 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Fm2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Md2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='No2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Lr2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Rf2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Db2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Sg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Bh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Hs2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Mt2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Ds2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Rg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Cn2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Nh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Fl2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Mc2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Lv2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Ts2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config
Atoms(symbols='Og2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

Atomic number exceeds that given in model config


In [8]:


# Column order of the per-model summary table written to homonuclear-diatomics.json.
COLUMNS = [
    "name",
    "method",
    "R", "E", "F", "S^2",
    "force-flip-times",
    "force-total-variation",
    "force-jump",
    "energy-diff-flip-times",
    "energy-grad-norm-max",
    "energy-jump",
    "energy-total-variation",
    "tortuosity",
    "conservation-deviation",
    "spearman-descending-force",
    "spearman-ascending-force",
    "spearman-repulsion-energy",
    "spearman-attraction-energy",
]


def _sign_flips_and_jump(diffs, tol=0.0):
    """Locate sign changes in a sequence of consecutive differences.

    Differences with ``|d| < tol`` are treated as zero, and zero differences are
    dropped before neighbouring signs are compared.

    Returns:
        (flips, jump): boolean mask of sign changes between successive remaining
        differences, and the summed magnitude of the differences on both sides of
        every flip.
    """
    diffs = np.where(np.abs(diffs) < tol, 0.0, diffs)
    diffs = diffs[np.sign(diffs) != 0]
    flips = np.diff(np.sign(diffs)) != 0
    jump = np.abs(diffs[:-1][flips]).sum() + np.abs(diffs[1:][flips]).sum()
    return flips, jump


def _spearman(x, y):
    """Spearman rank correlation of x and y, or NaN when fewer than two points are given."""
    if len(x) < 2:
        return np.nan
    return stats.spearmanr(x, y).statistic


def _diatomic_metrics(traj, name, method):
    """Summarise one homonuclear-diatomic distance scan.

    Args:
        traj: two-atom ASE ``Atoms`` frames carrying energies and forces
            (at least two frames).
        name: pair label, e.g. ``"HH"``.
        method: model name stored in the ``method`` column.

    Returns:
        dict keyed by ``COLUMNS``. ``R``/``E``/``F`` are ordered from the largest
        to the smallest separation; ``E`` holds the raw model energies.
    """
    rs, es, fs = [], [], []
    for atoms in traj:
        vec = atoms.positions[1] - atoms.positions[0]
        r = np.linalg.norm(vec)
        rs.append(r)
        es.append(atoms.get_potential_energy())
        # Force on atom 1 projected on the unit bond vector (positive = pushing the atoms apart).
        fs.append(np.inner(vec / r, atoms.get_forces()[1]))

    order = np.argsort(rs)[::-1]
    rs = np.asarray(rs)[order]
    e_raw = np.asarray(es)[order]
    fs = np.asarray(fs)[order]
    es = e_raw - e_raw[0]  # energies relative to the largest separation

    iminf = np.argmin(fs)
    imine = np.argmin(es)
    de_dr = np.gradient(es, rs)

    # Bond-force sign changes, ignoring |F| < 1e-2 as numerical noise.
    f_sign = np.sign(np.where(np.abs(fs) < 1e-2, 0.0, fs))
    f_flip = np.diff(f_sign[f_sign != 0]) != 0

    _, fjump = _sign_flips_and_jump(np.diff(fs))
    ediff_flip, ejump = _sign_flips_and_jump(np.diff(es), tol=1e-3)

    etv = np.sum(np.abs(np.diff(es)))

    return {
        "name": name,
        "method": method,
        "R": rs,
        "E": e_raw,
        "F": fs,
        # Placeholder: mean squared magnetic moment is not extracted here.
        "S^2": [],
        "force-flip-times": np.sum(f_flip),
        "force-total-variation": np.sum(np.abs(np.diff(fs))),
        "force-jump": fjump,
        "energy-diff-flip-times": np.sum(ediff_flip),
        "energy-grad-norm-max": np.max(np.abs(de_dr)),
        "energy-jump": ejump,
        "energy-total-variation": etv,
        "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
        # For a conservative model F = -dE/dr, so this should be close to zero.
        "conservation-deviation": np.mean(np.abs(fs + de_dr)),
        "spearman-descending-force": _spearman(rs[iminf:], fs[iminf:]),
        "spearman-ascending-force": _spearman(rs[:iminf], fs[:iminf]),
        "spearman-repulsion-energy": _spearman(rs[imine:], es[imine:]),
        "spearman-attraction-energy": _spearman(rs[:imine], es[:imine]),
    }


for model in MLIPEnum:

    model_name = model.name

    print(f"========== {model_name} ==========")

    family_dir = Path(REGISTRY[model_name]["family"])

    records = []
    for symbol in tqdm(chemical_symbols[1:]):
        da = symbol + symbol
        traj_fpath = family_dir / da / f"{model_name}.extxyz"

        if not traj_fpath.exists():
            continue

        traj = read(traj_fpath, index=":")
        if len(traj) < 2:
            # np.gradient needs at least two samples; a 1-frame scan has no curve.
            continue

        records.append(_diatomic_metrics(traj, da, model_name))

    # Build the frame once instead of concatenating row by row.
    df = pd.DataFrame(records, columns=COLUMNS)

    json_fpath = family_dir / "homonuclear-diatomics.json"

    if json_fpath.exists():
        previous = pd.read_json(json_fpath)
        combined = previous if df.empty else pd.concat([previous, df], ignore_index=True)
        # Newly computed rows replace stored rows for the same pair and method.
        df = combined.drop_duplicates(subset=["name", "method"], keep="last")

    df.to_json(json_fpath, orient="records")



  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)


In [9]:
# Metrics table as written to homonuclear-diatomics.json for the last model iterated above.
df

Unnamed: 0,name,method,R,E,F,S^2,force-flip-times,force-total-variation,force-jump,energy-diff-flip-times,energy-grad-norm-max,energy-jump,energy-total-variation,tortuosity,conservation-deviation,spearman-descending-force,spearman-ascending-force,spearman-repulsion-energy,spearman-attraction-energy
0,HH,eqV2(OMat),"[3.7199999999999998, 3.70996794, 3.69993586, 3...","[-2.0984511375427246, -2.095881462097168, -2.0...","[9.72e-06, 7.52e-06, 3.211e-05, 3.564e-05, 3.9...",[],2,106.606564,1.924082,17,93.241588,0.564488,18.840863,1.555587,2.842868,-0.994359,-0.072700,-0.992857,0.431210
1,HeHe,eqV2(OMat),"[4.433, 4.4229736200000005, 4.41294724, 4.4029...","[0.5383987426757812, 0.5361838340759277, 0.536...","[-1.365e-05, -4.945e-05, -8.538e-05, -0.000155...",[],2,265.816823,5.186691,14,77.250589,0.699341,19.451416,4.166659,10.740978,-0.973671,-0.546770,-0.912122,0.600920
2,LiLi,eqV2(OMat),"[6.572000000000001, 6.561981520000001, 6.55196...","[-0.5126352310180664, -0.5116205215454102, -0....","[-3.34e-05, 0.00016922, 0.00028867, 0.0005224,...",[],1,28.116090,0.053238,31,17.085735,0.181582,7.866485,1.334280,1.063989,-0.998354,0.976426,-0.981398,0.705693
3,BeBe,eqV2(OMat),"[6.138000000000001, 6.12797338, 6.117946759999...","[0.23909759521484375, 0.2400522232055664, 0.23...","[-4.817e-05, 0.00051783, 0.00090912, 0.0012907...",[],1,188.239996,0.177490,35,22.672879,0.328797,13.735467,1.275790,6.498375,-0.997688,0.983071,-0.972891,0.014737
4,BB,eqV2(OMat),"[5.921000000000001, 5.91097088, 5.900941739999...","[-1.388545036315918, -1.3869714736938477, -1.3...","[-7.486e-05, -0.00062429, -0.00124062, -0.0016...",[],1,142.788128,0.010691,19,30.771105,0.572540,13.649937,1.145869,4.867476,-1.000000,0.990445,-0.984379,0.993144
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
90,PaPa,eqV2(OMat),"[8.927999999999999, 8.91797468, 8.90794936, 8....","[-8.536531448364258, -8.537920951843262, -8.55...","[0.00149543, 0.00933722, 0.01594199, 0.0210447...",[],14,293.319609,0.622614,47,94.871632,1.276346,37.647346,3.460726,9.596265,-0.976682,0.916942,-1.000000,0.064439
91,UU,eqV2(OMat),"[8.401, 8.390974320000002, 8.38094864, 8.37092...","[-12.306029319763184, -12.309735298156738, -12...","[0.00250636, -0.00677688, -0.01450382, -0.0195...",[],9,308.464112,1.041690,53,74.002914,2.503158,36.593965,3.514434,9.661410,-0.853855,0.856809,-1.000000,0.009604
92,NpNp,eqV2(OMat),"[8.741999999999999, 8.7319829, 8.72196582, 8.7...","[-16.419757843017578, -16.415634155273438, -16...","[0.00010047, 0.00343876, 0.00583172, 0.0066036...",[],13,366.700701,2.488891,58,145.748637,3.159492,68.312851,6.846896,14.868774,-0.970815,0.783776,-1.000000,-0.092599
93,PuPu,eqV2(OMat),"[8.710999999999999, 8.70097432, 8.690948639999...","[-21.575969696044922, -21.568323135375977, -21...","[0.00021465, -0.018026, -0.02833581, -0.022387...",[],25,404.453349,20.662893,85,195.927256,14.399994,78.172195,5.575810,15.614091,-0.661178,0.348212,-0.002303,0.654905
