Source code for maxwelllink.mxl_drivers.python.models.ase_model

import os
import numpy as np
from typing import Optional, Sequence, Union, Dict
import ast

try:
    from .dummy_model import DummyModel
except:
    from dummy_model import DummyModel

try:
    from ase import units
except ImportError as e:
    raise ImportError(
        "ASE package is required for ASEModel. Please install it via 'conda install conda_forge::ase'."
    ) from e

from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.md.verlet import VelocityVerlet
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution

# from ase.md.langevin import Langevin
from ase.io import read as ase_read

# units parameters
from maxwelllink.units import (
    FS_TO_AU,
    FORCE_PER_EFIELD_AU_EV_PER_ANG,
    BOHR_PER_ANG,
    AMU_TO_AU,
)


def _parse_kwargs_string(s: str) -> Dict:
    """
    Parse a compact ``'k1=v1,k2=v2'`` string into a dictionary with numbers/bools
    auto-cast. (Used for preset params and for passing kwargs to a user spec.)

    Parameters
    ----------
    s : str
        Input string.

    Returns
    -------
    dict
        Parsed key-value pairs with best-effort type casting.
    """

    if not s:
        return {}

    def _strip_quotes(s):
        if not isinstance(s, str):
            return s
        s = s.strip()
        if (len(s) >= 2) and ((s[0] == s[-1]) and s[0] in ("'", '"')):
            return s[1:-1]
        return s

    s = _strip_quotes(s)
    out = {}
    for token in s.split(","):
        if not token.strip():
            continue
        if "=" not in token:
            # allow bare flags as True
            out[token.strip()] = True
            continue
        k, v = token.split("=", 1)
        k = k.strip()
        v = v.strip()
        # try literal (int/float/bool/list/tuple) -> else string
        try:
            out[k] = ast.literal_eval(v)
        except Exception:
            # also allow "true"/"false"
            lv = v.lower()
            if lv in ("true", "false"):
                out[k] = lv == "true"
            else:
                out[k] = v
    return out


def _pick_calculator_class(mod, n):
    """
    Pick the ASE ``Calculator`` subclass to instantiate from module ``mod``.

    Preference order: a class whose (lower-cased) binding or class name equals
    ``n``; then the first subclass defined in ``mod`` itself (so imported base
    classes such as ``FileIOCalculator`` are skipped); then any other subclass
    found. Candidates are scanned in ``dir(mod)`` order.

    Parameters
    ----------
    mod : module
        Imported ``ase.calculators.<n>`` module.
    n : str
        Lower-cased calculator name.

    Returns
    -------
    type or None
        The chosen class, or ``None`` if the module exposes no subclass.
    """
    exact = local = other = None
    for attr in dir(mod):
        obj = getattr(mod, attr, None)
        if not isinstance(obj, type) or obj is Calculator:
            continue
        if not issubclass(obj, Calculator):
            continue
        if exact is None and n in (attr.lower(), obj.__name__.lower()):
            exact = obj
        if local is None and getattr(obj, "__module__", None) == mod.__name__:
            local = obj
        if other is None:
            other = obj
    for cls in (exact, local, other):
        if cls is not None:
            return cls
    return None


def _build_calculator(name: str, **kwargs) -> Calculator:
    """
    Minimal factory for common ASE calculators.

    Examples
    --------
    - ``name='psi4'`` → Psi4 via ``ase.calculators.psi4`` (Psi4 binary required)
    - ``name='dftb'`` → install DFTB+ (and set DFTB+ binary)
    - ``name='orca'`` → ORCA via ``ase.calculators.orca`` (ORCA binary required)

    Any other name is resolved by importing ``ase.calculators.<name>`` and
    instantiating the best-matching ``Calculator`` subclass found there
    (see :func:`_pick_calculator_class`).

    Parameters
    ----------
    name : str
        Name of the calculator (case-insensitive, surrounding spaces ignored).
    **kwargs
        Additional keyword arguments passed to the calculator constructor.

    Returns
    -------
    ase.calculators.calculator.Calculator
        An ASE Calculator instance.

    Raises
    ------
    ImportError
        If the fallback module cannot be imported, contains no Calculator
        subclass, or its constructor fails (the original error is chained as
        ``__cause__``). Constructor errors of the explicitly supported
        calculators (psi4/orca/dftb) propagate unchanged.
    """
    import importlib

    n = (name or "").strip().lower()

    if n == "psi4":
        from ase.calculators.psi4 import Psi4

        return Psi4(**kwargs)

    if n == "orca":
        from ase.calculators.orca import ORCA

        return ORCA(**kwargs)

    if n == "dftb":
        from ase.calculators.dftb import Dftb

        return Dftb(**kwargs)

    # Fallback: try to import path "ase.calculators.<name>"
    try:
        mod = importlib.import_module(f"ase.calculators.{n}")
    except Exception as e:
        raise ImportError(
            f"Unknown calculator '{name}'. Install or extend _build_calculator()."
        ) from e

    cls = _pick_calculator_class(mod, n)
    if cls is None:
        raise ImportError(
            f"Failed to construct calculator for '{name}'. Extend _build_calculator()."
        )
    # Construct outside any silent except so real errors (e.g. bad kwargs)
    # are visible through the exception chain.
    try:
        return cls(**kwargs)
    except Exception as e:
        raise ImportError(
            f"Failed to construct calculator for '{name}'. Extend _build_calculator()."
        ) from e


class ForceAugmenter(Calculator):
    r"""
    ASE Calculator wrapper that adds an external uniform E-field force
    :math:`F_i^{\mathrm{ext}} = q_i \mathbf{E}` to each atom :math:`i`, where
    :math:`q_i` are per-atom charges (either fixed or recomputed each step).

    The wrapped ``base`` calculator provides the field-free energy and forces.
    Field-free results for the most recently evaluated geometry are cached, so a
    new field value at an unchanged geometry only re-evaluates the external
    force term instead of re-running ``base``.

    Notes
    -----
    The reported ``energy`` is the base calculator's energy; the field
    interaction energy is not added to it.
    """

    implemented_properties = ("energy", "forces")

    def __init__(self, base, charges=None, recompute_charges=False, verbose=False):
        r"""
        Parameters
        ----------
        base : ase.calculators.calculator.Calculator
            An ASE Calculator instance to wrap.
        charges : array-like or None, optional
            Per-atom charges in :math:`\lvert e \rvert` units. If ``None``, set
            ``recompute_charges=True``.
        recompute_charges : bool, default: False
            If ``True``, query charges each step (e.g., Mulliken) via
            ``base.get_charges``.
        verbose : bool, default: False
            Whether to print verbose output.
        """
        super().__init__()
        self.base = base
        self.charges = None if charges is None else np.asarray(charges, float).copy()
        self.recompute_charges = bool(recompute_charges)
        self.verbose = bool(verbose)
        # external uniform field in atomic units
        self._E_au = np.zeros(3, float)
        # small caches for saving the computed (field-free) results for one geometry
        self._cache_key = None
        self._cache_energy = None
        self._cache_forces = None
        # dipole (e * bohr, shape (1, 3)) from the last external-force evaluation
        self._cache_dipole_vec = None

    def set_field_au(self, Evec3_au):
        """
        Set the external uniform E-field vector in atomic units (a.u.).

        Parameters
        ----------
        Evec3_au : array-like of float, shape (3,)
            3-element array-like representing the E-field vector in a.u.
        """
        self._E_au = np.asarray(Evec3_au, float).reshape(3)

    def _geom_key(self, atoms: Atoms):
        """
        Build a key for the current molecular geometry.

        Parameters
        ----------
        atoms : ase.Atoms
            An ASE ``Atoms`` object.

        Returns
        -------
        tuple
            A hashable key representing the current geometry (positions, cell,
            numbers).
        """
        pos = atoms.get_positions()
        cell = atoms.cell.array
        Z = atoms.get_atomic_numbers()
        return (
            pos.shape,
            pos.tobytes(),
            cell.shape,
            cell.tobytes(),
            Z.shape,
            Z.tobytes(),
        )

    def _cache_covers(self, key_now, properties):
        """Return True if the cache holds every requested field-free property for ``key_now``."""
        if self._cache_key is None or key_now != self._cache_key:
            return False
        if "forces" in properties and self._cache_forces is None:
            return False
        if "energy" in properties and self._cache_energy is None:
            return False
        return True

    def calculation_required(self, atoms, properties):
        """
        Determine whether a recalculation is required based on changes in the
        atomic configuration.

        Parameters
        ----------
        atoms : ase.Atoms
            An ASE ``Atoms`` object.
        properties : tuple
            Properties to be calculated (e.g., ``'energy'``, ``'forces'``).

        Returns
        -------
        bool
            Whether a recalculation is required.
        """
        # If we have cached forces/energy for the current geometry, no recalc needed.
        if self._cache_covers(self._geom_key(atoms), properties):
            return False
        # Otherwise, defer to the base (positions/cell/numbers changes) -> recalc
        if hasattr(self.base, "calculation_required"):
            return self.base.calculation_required(atoms, properties)
        return super().calculation_required(atoms, properties)

    def _resolve_charges(self, atoms):
        """Return the per-atom charges, re-querying ``base`` when ``recompute_charges`` is set."""
        if self.recompute_charges:
            getq = getattr(self.base, "get_charges", None)
            if callable(getq):
                try:
                    self.charges = np.asarray(getq(atoms), float)
                except Exception:
                    # best effort: keep the last known charges (if any)
                    pass
            if self.charges is None:
                raise RuntimeError(
                    "recompute_charges=True but base has no get_charges(); pass fixed 'charges=' instead."
                )
            return self.charges
        if self.charges is None:
            raise RuntimeError(
                "Need 'charges=' or recompute_charges=True with a supporting calculator."
            )
        return self.charges

    def calculate_external_force(self, atoms):
        """
        Calculate the external force on each atom due to the uniform E-field.

        As a side effect, caches the charge dipole ``sum_i q_i r_i`` (in
        e * bohr, shape (1, 3)) in ``self._cache_dipole_vec``.

        Parameters
        ----------
        atoms : ase.Atoms
            An ASE ``Atoms`` object.

        Returns
        -------
        numpy.ndarray of float, shape (N, 3)
            External forces on each atom in eV/Å (ASE internal force units).

        Raises
        ------
        RuntimeError
            If no charges are available.
        ValueError
            If the number of charges does not match the number of atoms.
        """
        q = np.asarray(self._resolve_charges(atoms), float).reshape(-1)
        n_atoms = len(atoms)
        if q.size != n_atoms:
            raise ValueError(
                f"Got {q.size} per-atom charges for {n_atoms} atoms."
            )
        # Add uniform-field force in eV/Angstrom
        Fext = q.reshape(-1, 1) * (self._E_au * FORCE_PER_EFIELD_AU_EV_PER_ANG)
        # also store the dipole vector (Angstrom -> bohr) for later use
        self._cache_dipole_vec = (q.reshape(1, -1) @ atoms.get_positions()) * BOHR_PER_ANG
        return Fext

    def calculate(self, atoms=None, properties=("energy",), system_changes=all_changes):
        """
        Calculate the requested properties for the given atomic configuration.

        Parameters
        ----------
        atoms : ase.Atoms, optional
            An ASE ``Atoms`` object. Falls back to ``self.atoms`` if ``None``.
        properties : tuple, default: ('energy',)
            Properties to be calculated (e.g., ``'energy'``, ``'forces'``).
        system_changes : list
            List of changes in the atomic configuration.

        Raises
        ------
        ValueError
            If no ``Atoms`` object is available.
        RuntimeError
            If forces are requested but the base calculator provides none.
        """
        if atoms is None:
            atoms = self.atoms
        if atoms is None:
            raise ValueError("ForceAugmenter.calculate() needs an Atoms object.")

        key_now = self._geom_key(atoms)

        # Serve from cache when it covers everything requested; the field term
        # is always recomputed because the field may have changed.
        if self._cache_covers(key_now, properties):
            if "energy" in properties:
                self.results["energy"] = self._cache_energy
            if "forces" in properties:
                self.results["forces"] = (
                    self._cache_forces + self.calculate_external_force(atoms)
                )
            return

        # Ask base ONLY for what was requested
        props_for_base = tuple(
            p for p in properties if p in getattr(self.base, "implemented_properties", ())
        )
        if not props_for_base:
            # Fallback: many base calculators accept energy-only
            props_for_base = ("energy",) if "energy" in properties else properties

        # Clear results for this call
        self.results.clear()

        # Compute on base
        self.base.calculate(atoms, props_for_base, system_changes)

        # Copy energy whenever the base produced it (also on forces-only
        # requests, so it can be cached alongside the forces).
        e_base = self.base.results.get("energy", None)
        if e_base is not None:
            self.results["energy"] = float(e_base)
        elif "energy" in properties:
            # Not all calculators compute energy on a forces-only request
            self.results["energy"] = None

        # Forces: only if requested now
        if "forces" in properties:
            f_base = self.base.results.get("forces", None)
            if f_base is None:
                raise RuntimeError(
                    "Base calculator did not provide forces when requested."
                )
            f = np.array(f_base, dtype=float, copy=True)
            # Update caches with the field-free results for this geometry
            self._cache_key = key_now
            self._cache_energy = self.results.get("energy", None)
            self._cache_forces = f.copy()
            self.results["forces"] = f + self.calculate_external_force(atoms)
class ASEModel(DummyModel):
    r"""
    General BOMD (Born-Oppenheimer MD) driver using ASE.

    This model provides the two key functionalities like other Models supported
    by MaxwellLink:

    1. **MD coupled to E-field**: Injects :math:`F_i^{\mathrm{ext}} = q_i \mathbf{E}`
       (uniform field) to the molecular forces in MD simulations, where
       :math:`q_i` are per-atom charges (constant user-supplied or
       calculator-reported each step).
    2. **Return source amplitude**:
       :math:`\dot{\boldsymbol{\mu}} = \sum_i q_i \mathbf{v}_i` (converted to
       atomic units).

    Unit handling: ASE works in Å, eV, amu and its own internal time unit
    (one femtosecond is ``ase.units.fs`` internal time units). Velocities and
    energies taken from ASE are converted before being reported in a.u.
    """

    def __init__(
        self,
        atoms: Union[str, Atoms],
        calculator: str = "psi4",
        calc_kwargs: str = "",
        charges: Optional[Union[str, Sequence[float]]] = None,
        recompute_charges: bool = False,
        temperature_K: float = 0.0,
        verbose: bool = False,
        checkpoint: bool = False,
        restart: bool = False,
        **extra,
    ):
        r"""
        Parameters
        ----------
        atoms : ase.Atoms or str
            Either an ASE ``Atoms`` object or a path to a structure file (e.g.,
            ``.xyz``) readable by ASE.
        calculator : str, default: 'psi4'
            Name of ASE calculator (``'psi4'``, ``'dftb'``, ``'orca'``, ...).
        calc_kwargs : str, optional
            String of kwargs passed to the calculator constructor.
        charges : str or sequence of float or None, optional
            Per-atom charges (in :math:`\lvert e \rvert`). Either a sequence of
            floats or a string like ``"[-1.0 1.0]"`` whose values are separated
            by spaces (commas are also accepted). If ``None``, set
            ``recompute_charges=True``.
        recompute_charges : bool, default: False
            If ``True``, query charges each step (e.g., Mulliken).
        temperature_K : float, default: 0.0
            Initial temperature in Kelvin for the Maxwell–Boltzmann
            distribution (applied when ``>= 0``).
        verbose : bool, default: False
            Whether to print verbose output.
        checkpoint : bool, default: False
            Whether to enable checkpointing.
        restart : bool, default: False
            Whether to restart from a checkpoint if available.
        **extra
            Additional keyword arguments forwarded to the calculator constructor.

        Raises
        ------
        ValueError
            If ``charges`` cannot be converted into a float array.
        RuntimeError
            If neither ``charges`` nor ``recompute_charges=True`` is given.
        """
        super().__init__(verbose=verbose, checkpoint=checkpoint, restart=restart)
        # atoms
        if isinstance(atoms, Atoms):
            self.atoms = atoms.copy()
        else:
            # treat as file path
            self.atoms = ase_read(str(atoms))

        self.calc_name = calculator
        # the input for calc_kwargs is as follows: --params 'xx=xx, calc_kwargs=k1=v1,k2=v2, yy=yy'
        # recover the first hit (k1=v1)
        self.calc_kwargs = _parse_kwargs_string(calc_kwargs)
        # all the other extra kwargs go into calc_kwargs (such as k2=v2)
        _own_keys = {
            "atoms",
            "calculator",
            "calc_kwargs",
            "charges",
            "recompute_charges",
            "n_substeps",
            "temperature_K",
            "verbose",
            "checkpoint",
            "restart",
        }
        for k in list(extra.keys()):
            if k not in _own_keys:
                self.calc_kwargs[k] = extra.pop(k)

        print("[ASEModel] calculator name =", self.calc_name)
        print("[ASEModel] calculator kwargs =", self.calc_kwargs)

        self.user_charges = self._parse_charges(charges)
        self.recompute_charges = bool(recompute_charges)
        # now, let's only run MD for one time step at a time
        self.n_substeps = 1
        self.temperature_K = temperature_K

        print("[ASEModel] user_charges =", self.user_charges)
        print("[ASEModel] recompute_charges =", self.recompute_charges)
        print("[ASEModel] n_substeps =", self.n_substeps)
        print("[ASEModel] temperature_K =", self.temperature_K)

        if self.user_charges is None and not self.recompute_charges:
            raise RuntimeError(
                "ASEModel needs charges: pass 'charges=' or set 'recompute_charges=True' with calculator support."
            )

        self.integrator = None
        self.forcewrap = None
        self._last_amp = np.zeros(3)  # cached for dmu/dt
        self._charges = None
        self._vel_angs_per_fs = None

        # dipole and dmudt info in previous time steps
        self.dipole_prev = None
        self.dmudt_prev = None
        self.dipole_middlepoint = None
        self.dmudt_middlepoint = None
        self.dipole_projected = None
        self.dmudt_projected = None
        # nuclear kinetic energy in Hartree
        self.kinEnuc = 0.0

    @staticmethod
    def _parse_charges(charges):
        """
        Convert the ``charges`` argument into a 1-D float array, or ``None``.

        Accepts a string such as ``"[0.1 -0.2]"`` (space- or comma-separated,
        optional surrounding brackets) or any numeric sequence.

        Raises
        ------
        ValueError
            If the input cannot be converted to floats.
        """
        if charges is None:
            return None
        msg = "Failed to parse 'charges' string into array; use format like '[0.1 -0.2 0.0 ...]'"
        if isinstance(charges, str):
            tokens = charges.strip().strip("[]").replace(",", " ").split()
            try:
                return np.array([float(tok) for tok in tokens], dtype=float)
            except ValueError as e:
                raise ValueError(msg) from e
        try:
            return np.asarray(charges, dtype=float).reshape(-1)
        except (TypeError, ValueError) as e:
            raise ValueError(msg) from e

    def _potential_energy_au(self):
        """Return the last cached base-calculator energy converted from eV to Hartree, or ``None``."""
        e_ev = self.forcewrap._cache_energy if self.forcewrap is not None else None
        return None if e_ev is None else float(e_ev) / units.Hartree

    # -------------- heavy-load initialization (at INIT) --------------

    def initialize(self, dt_new, molecule_id):
        """
        Initialize the model with the new time step and molecule ID.

        Parameters
        ----------
        dt_new : float
            The new time step in atomic units (a.u.).
        molecule_id : int
            The ID of the molecule.

        Raises
        ------
        ValueError
            If the resulting MD time step is not positive.
        """
        self.dt = float(dt_new)
        self.molecule_id = int(molecule_id)
        self.checkpoint_filename = f"ase_checkpoint_id_{self.molecule_id}.npz"

        # MD step in fs
        dt_fs = (self.dt / FS_TO_AU) / self.n_substeps
        if dt_fs <= 0.0:
            raise ValueError("Non-positive dt_fs computed.")

        # Base calculator and force wrapper
        base_calc = _build_calculator(self.calc_name, **self.calc_kwargs)
        self.forcewrap = ForceAugmenter(
            base=base_calc,
            charges=self.user_charges,
            recompute_charges=self.recompute_charges,
            verbose=self.verbose,
        )
        self.atoms.calc = self.forcewrap

        # Initialize velocities
        if self.temperature_K >= 0.0:
            MaxwellBoltzmannDistribution(
                self.atoms, temperature_K=float(self.temperature_K)
            )

        if self.checkpoint and self.restart:
            self._reset_from_checkpoint()
            self.restarted = True

        # Choose the VelocityVerlet integrator
        self.integrator = VelocityVerlet(
            self.atoms, timestep=dt_fs * units.fs, logfile=None, loginterval=0
        )

        if self.verbose:
            print(
                f"[ASEModel {self.molecule_id}] dt={self.dt:.6e} a.u. "
                f"-> dt_fs={dt_fs:.6e} fs; substeps={self.n_substeps}; "
                f"calculator={self.calc_name}({self.calc_kwargs})"
            )

    # -------------- one FDTD step under E-field --------------

    def propagate(self, effective_efield_vec):
        """
        Propagate the BO molecular dynamics given the effective electric field vector.

        Parameters
        ----------
        effective_efield_vec : array-like of float, shape (3,)
            Effective electric field vector in the form ``[E_x, E_y, E_z]``.
        """
        # 1. set the field for the wrapper (a.u.)
        self.E_vec = np.asarray(effective_efield_vec, float).reshape(3)
        if self.verbose:
            print(
                f"[ASEModel {self.molecule_id}] t={self.t:.6f} a.u., Efield={self.E_vec} a.u."
            )
        self.forcewrap.set_field_au(self.E_vec)

        # 2. do n_substeps
        self.integrator.run(self.n_substeps)

        # 3. cache per-atom charges (either recomputed or fixed) & velocities
        self._charges = (
            self.forcewrap.charges if self.recompute_charges else self.user_charges
        )

        vel = self.atoms.get_velocities()
        if vel is None:
            vel = np.zeros((len(self.atoms), 3), float)
        # ASE velocities are in Angstrom per ASE time unit; one fs equals
        # units.fs ASE time units, so multiplying by units.fs gives Angstrom/fs.
        self._vel_angs_per_fs = np.asarray(vel, float) * units.fs

        # kinetic energy in atomic units (Hartree)
        v_au = self._vel_angs_per_fs * (BOHR_PER_ANG / FS_TO_AU)
        masses_au = np.asarray(self.atoms.get_masses(), float) * AMU_TO_AU
        self.kinEnuc = float(0.5 * np.sum(masses_au[:, None] * v_au * v_au))

        if self.verbose:
            e_pot_au = self._potential_energy_au()
            e_tot_au = None if e_pot_au is None else e_pot_au + self.kinEnuc
            print(
                f"[ASEModel {self.molecule_id}] t={self.t:.6e} au, Efield_au={self.forcewrap._E_au} a.u., "
                f"q={self._charges}, v_angs_per_fs={self._vel_angs_per_fs}, energy_au={e_tot_au} a.u., "
                f"energy_kin_au={self.kinEnuc} a.u."
            )

        # advance model time in a.u.
        self.t += self.dt

    def calc_amp_vector(self):
        r"""
        Return the amplitude vector :math:`\mathrm{d}\boldsymbol{\mu}/\mathrm{d}t`
        for the current time step in atomic units.

        In classical MD:
        :math:`\displaystyle \frac{\mathrm{d}\boldsymbol{\mu}}{\mathrm{d}t} = \sum_i q_i \mathbf{v}_i`.

        Returns
        -------
        numpy.ndarray of float, shape (3,)
            Zeros before the first step; otherwise the projected value.
        """
        if self._vel_angs_per_fs is None or self._charges is None:
            return np.zeros(3, float)
        v_au = self._vel_angs_per_fs * (BOHR_PER_ANG / FS_TO_AU)
        amp = (self._charges.reshape(-1, 1) * v_au).sum(axis=0)
        # MaxwellLink sends E-field at time step n, expects amp at time step n+1/2.
        # However, with velocity verlet, if E-field is sent at time step n, then both
        # velocity and position are updated to step n at the final stage of velocity
        # verlet. Therefore, we need to do a simple linear extrapolation here to get
        # amp at time step n+1/2.
        self.dmudt_prev = (
            self.dmudt_projected.copy() if self.dmudt_projected is not None else amp.copy()
        )
        self.dmudt_middlepoint = amp.copy()
        self.dmudt_projected = 2.0 * self.dmudt_middlepoint - self.dmudt_prev
        return self.dmudt_projected

    # ------------ optional operation / checkpoint --------------

    def append_additional_data(self):
        """
        Append additional data to be sent back to MaxwellLink.

        The data can be retrieved by the user via the Python interface:
        ``maxwelllink.SocketMolecule.additional_data_history``, where
        ``additional_data_history`` is a list of dictionaries.

        Returns
        -------
        dict
            A dictionary containing additional data. ``energy_au`` is the base
            potential energy (converted eV -> Hartree) plus the nuclear kinetic
            energy, or ``0.0`` if no potential energy is cached yet.
        """
        # MaxwellLink sends E-field at time step n, expects amp at time step n+1/2.
        # However, with velocity verlet, if E-field is sent at time step n, then both
        # velocity and position are updated to step n at the final stage of velocity
        # verlet. Therefore, we need to do a simple linear extrapolation here to get
        # dipole at time step n+1/2.
        dip_now = self.forcewrap._cache_dipole_vec
        # no force evaluation yet -> no dipole available; report zeros
        dip_now = (
            np.zeros(3, float) if dip_now is None else np.asarray(dip_now, float).reshape(-1)
        )
        self.dipole_prev = (
            self.dipole_projected.copy().reshape(-1)
            if self.dipole_projected is not None
            else dip_now.copy()
        )
        self.dipole_middlepoint = dip_now.copy()
        self.dipole_projected = 2.0 * self.dipole_middlepoint - self.dipole_prev

        e_pot_au = self._potential_energy_au()
        d = {
            "time_au": float(self.t),
            "energy_au": float(e_pot_au + self.kinEnuc) if e_pot_au is not None else 0.0,
            "energy_kin_au": float(self.kinEnuc),
            "mux_au": float(self.dipole_projected[0]),
            "muy_au": float(self.dipole_projected[1]),
            "muz_au": float(self.dipole_projected[2]),
            "mux_m_au": float(self.dipole_middlepoint[0]),
            "muy_m_au": float(self.dipole_middlepoint[1]),
            "muz_m_au": float(self.dipole_middlepoint[2]),
            "temperature_K": float(self.atoms.get_temperature()),
        }
        return d

    def _dump_to_checkpoint(self):
        """
        Dump the internal state of the model (time, positions, ASE velocities)
        to ``self.checkpoint_filename``.
        """
        np.savez(
            self.checkpoint_filename,
            t=self.t,
            positions=self.atoms.get_positions(),
            velocities=self.atoms.get_velocities(),
        )

    def _reset_from_checkpoint(self):
        """
        Reset the internal state of the model from a checkpoint, if one exists.
        """
        if not os.path.exists(self.checkpoint_filename):
            # No checkpoint file found means this driver has not been paused or
            # terminated abnormally, so we just start fresh.
            if self.verbose:
                print(
                    f"[ASEModel] No checkpoint file found for id={self.molecule_id}, starting fresh."
                )
            return
        # NpzFile keeps the archive open; the context manager closes it.
        with np.load(self.checkpoint_filename) as data:
            self.t = float(data["t"])
            pos = np.asarray(data["positions"], float)
            vel = np.asarray(data["velocities"], float)
        self.atoms.set_positions(pos)
        self.atoms.set_velocities(vel)
        if self.verbose:
            print(f"[ASEModel] Restarted from checkpoint for id={self.molecule_id}")

    def _snapshot(self):
        """
        Return a snapshot of the internal state for propagation.

        Notes
        -----
        Deep copy the arrays to avoid mutation issues.

        Returns
        -------
        dict
            A dictionary containing the snapshot of the internal state.
        """
        return {
            "time": self.t,
            "positions": self.atoms.get_positions().copy(),
            "velocities": self.atoms.get_velocities().copy(),
        }

    def _restore(self, snapshot):
        """
        Restore the internal state from a snapshot.

        Parameters
        ----------
        snapshot : dict
            A dictionary containing the snapshot of the internal state.
        """
        self.t = snapshot["time"]
        self.atoms.set_positions(snapshot["positions"])
        self.atoms.set_velocities(snapshot["velocities"])