#!/usr/bin/env python3
from __future__ import annotations
import argparse
import subprocess
import shlex
import time
import shutil
import os
import socket, json
import numpy as np
import struct
try:
from .models import __drivers__
from .models.dummy_model import DummyModel
except ImportError:
from models import __drivers__
from models.dummy_model import DummyModel
description = """
A Python driver connecting to MaxwellLink, receiving E-field data and returning
the source amplitude vector for a quantum dynamics model.
"""
# helper function to determine whether this processor is the MPI master using mpi4py
def _am_master():
"""
Return True if this process is the MPI master rank (rank 0), otherwise False.
Notes
-----
Attempts to import ``mpi4py`` and query ``COMM_WORLD``. If ``mpi4py`` is not
available, returns ``True`` by treating the single process as rank 0.
"""
try:
from mpi4py import MPI as _MPI
_COMM = _MPI.COMM_WORLD
_RANK = _COMM.Get_rank()
except Exception:
_COMM = None
_RANK = 0
return _RANK == 0
_INT32 = struct.Struct("<i")
_FLOAT64 = struct.Struct("<d")
# numpy dtypes on the wire
DT_FLOAT = np.float64
DT_INT = np.int32
# header width (ASCII, space-padded)
HEADER_LEN = 12
# Message codes similar to i-PI's socket interface
STATUS = b"STATUS"
READY = b"READY"
HAVEDATA = b"HAVEDATA"
NEEDINIT = b"NEEDINIT"
INIT = b"INIT"
POSDATA = b"POSDATA"
GETFORCE = b"GETFORCE"
FORCEREADY = b"FORCEREADY"
STOP = b"STOP"
BYE = b"BYE"
# EM aliases for readability (same wire format)
FIELDDATA = POSDATA
GETSOURCE = GETFORCE
SOURCEREADY = FORCEREADY
class _SocketClosed(OSError):
pass
def _pad12(msg: bytes) -> bytes:
"""
Left-pad or right-pad a message to a fixed 12-byte ASCII header.
Parameters
----------
msg : bytes
Message bytes to send in the fixed-width header.
Returns
-------
bytes
The message padded with spaces to 12 bytes.
Raises
------
ValueError
If the message is longer than the 12-byte header.
"""
if len(msg) > HEADER_LEN:
raise ValueError("Header too long")
return msg.ljust(HEADER_LEN, b" ")
def _send_msg(sock: socket.socket, msg: bytes) -> None:
"""
Send a 12-byte ASCII header (space-padded).
Parameters
----------
sock : socket.socket
Connected socket.
msg : bytes
Message tag to send (e.g., ``b"STATUS"``).
"""
sock.sendall(_pad12(msg))
def _recvall(sock: socket.socket, n: int) -> bytes:
"""
Read exactly ``n`` bytes from a socket.
Parameters
----------
sock : socket.socket
Connected socket.
n : int
Number of bytes to read.
Returns
-------
bytes
The data read.
Raises
------
_SocketClosed
If the peer closes the connection before all bytes are received.
"""
buf = bytearray()
while len(buf) < n:
chunk = sock.recv(n - len(buf))
if not chunk:
raise _SocketClosed("Peer closed")
buf.extend(chunk)
return bytes(buf)
def _recv_msg(sock: socket.socket) -> bytes:
"""
Receive a 12-byte ASCII header and strip trailing spaces.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
bytes
The received header (without trailing spaces).
"""
hdr = _recvall(sock, HEADER_LEN)
return hdr.rstrip()
def _send_array(sock: socket.socket, arr, dtype) -> None:
"""
Send a NumPy array over a socket using a contiguous C-order memory view.
Parameters
----------
sock : socket.socket
Connected socket.
arr : array-like
Array data to send.
dtype : numpy.dtype
Data type to cast and send as (e.g., ``np.float64``).
"""
a = np.asarray(arr, dtype=dtype, order="C")
sock.sendall(memoryview(a).cast("B"))
def _recv_array(sock: socket.socket, shape, dtype):
"""
Receive a NumPy array of a given shape and dtype from a socket.
Parameters
----------
sock : socket.socket
Connected socket.
shape : tuple of int
Expected array shape.
dtype : numpy.dtype
Expected dtype (e.g., ``np.float64``).
Returns
-------
numpy.ndarray
The received array with the given shape and dtype.
Raises
------
_SocketClosed
If the peer closes the connection during the transfer.
"""
out = np.empty(shape, dtype=dtype, order="C")
mv = memoryview(out).cast("B")
need = mv.nbytes
got = 0
while got < need:
r = sock.recv_into(mv[got:], need - got)
if r == 0:
raise _SocketClosed("Peer closed")
got += r
return out
def _send_int(sock: socket.socket, x: int) -> None:
"""
Send a 32-bit little-endian integer.
Parameters
----------
sock : socket.socket
Connected socket.
x : int
Integer value to send.
"""
sock.sendall(_INT32.pack(int(x)))
def _recv_int(sock: socket.socket) -> int:
"""
Receive a 32-bit little-endian integer.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
int
The received integer.
Raises
------
_SocketClosed
If the peer closes the connection during the transfer.
"""
buf = bytearray(_INT32.size)
mv = memoryview(buf)
got = 0
while got < _INT32.size:
r = sock.recv_into(mv[got:], _INT32.size - got)
if r == 0:
raise _SocketClosed("Peer closed")
got += r
return _INT32.unpack(buf)[0]
def _send_bytes(sock: socket.socket, b: bytes) -> None:
"""
Send a length-prefixed byte string.
Parameters
----------
sock : socket.socket
Connected socket.
b : bytes
Byte string to send. The length is sent first as a 32-bit integer.
"""
_send_int(sock, len(b))
if len(b):
sock.sendall(b)
def _recv_bytes(sock: socket.socket) -> bytes:
"""
Receive a length-prefixed byte string.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
bytes
The received byte string (may be empty).
"""
n = _recv_int(sock)
return _recvall(sock, n) if n else b""
def _recv_posdata(sock: socket.socket):
"""
Read a POSDATA / FIELDDATA block from the socket.
Returns
-------
tuple
``(cell, icell, xyz)`` where:
- ``cell`` : ``(3, 3)`` ndarray (row-major), simulation cell.
- ``icell`` : ``(3, 3)`` ndarray (row-major), inverse cell.
- ``xyz`` : ``(nat, 3)`` ndarray of positions (or effective field payload).
"""
cell = _recv_array(sock, (3, 3), DT_FLOAT).T.copy()
icell = _recv_array(sock, (3, 3), DT_FLOAT).T.copy()
nat = _recv_int(sock)
xyz = _recv_array(sock, (nat, 3), DT_FLOAT)
return cell, icell, xyz
def _send_force_ready(
sock: socket.socket,
energy_ha: float,
forces_Nx3_ha_per_bohr,
virial_3x3_ha,
more: bytes = b"",
):
"""
Send a FORCEREADY/SOURCEREADY message with energy, forces, virial, and extras.
Parameters
----------
sock : socket.socket
Connected socket.
energy_ha : float
Total energy (Hartree).
forces_Nx3_ha_per_bohr : array-like
Forces as an ``(N, 3)`` array (Hartree/Bohr).
virial_3x3_ha : array-like
Virial tensor as a ``(3, 3)`` array (Hartree).
more : bytes, optional
Extra payload as bytes (length-prefixed), e.g., JSON metadata.
"""
_send_msg(sock, FORCEREADY)
_send_array(sock, np.array([energy_ha], dtype=DT_FLOAT), DT_FLOAT)
forces = np.asarray(forces_Nx3_ha_per_bohr, dtype=DT_FLOAT)
assert forces.ndim == 2 and forces.shape[1] == 3
_send_int(sock, forces.shape[0])
_send_array(sock, forces, DT_FLOAT)
_send_array(sock, np.asarray(virial_3x3_ha, dtype=DT_FLOAT).T, DT_FLOAT)
_send_bytes(sock, more)
# the above functions can be also obtained from maxwelllink.sockets,
# but we copy them here to avoid circular imports.
def _read_value(s):
"""
Attempt to parse a string as ``int`` or ``float``; fall back to string/boolean.
Parameters
----------
s : str
Input token.
Returns
-------
int or float or bool or str
Parsed value.
"""
s = s.strip()
for cast in (int, float):
try:
return cast(s)
except ValueError:
continue
if s.lower() == "false":
return False
if s.lower() == "true":
return True
return s
def _read_args_kwargs(input_str):
"""
Parse a comma-separated string into positional and keyword arguments.
Parameters
----------
input_str : str
Comma-separated tokens. Positional values are bare; keyword values use
``key=value``. Booleans accept ``true``/``false`` (case-insensitive).
Returns
-------
tuple
``(args, kwargs)`` where ``args`` is a list and ``kwargs`` is a dict.
"""
args = []
kwargs = {}
tokens = input_str.split(",")
for token in tokens:
token = token.strip()
if "=" in token:
key, value = token.split("=", 1)
kwargs[key.strip()] = _read_value(value)
elif len(token) > 0:
args.append(_read_value(token))
return args, kwargs
[docs]
def run_driver(
unix=False,
address="localhost",
port: int = 31415,
timeout: float = 600.0,
driver=DummyModel(),
sockets_prefix="/tmp/socketmxl_",
):
"""
Run the socket driver loop to communicate with MaxwellLink.
Parameters
----------
unix : bool, default: False
Use a UNIX domain socket if ``True``; otherwise use TCP/IP.
address : str, default: "localhost"
Hostname (TCP/IP) or UNIX socket name (when ``unix=True``).
port : int, default: 31415
TCP/IP port (ignored for UNIX sockets).
timeout : float, default: 600.0
Socket timeout in seconds.
driver : DummyModel, default: DummyModel()
Quantum dynamics model implementing the driver interface.
sockets_prefix : str, default: ``"/tmp/socketmxl_"``
Prefix for UNIX domain socket paths (ignored for TCP/IP).
Notes
-----
Implements a simple message protocol with headers such as ``STATUS``,
``INIT``, ``POSDATA``/``FIELDDATA``, ``GETFORCE``/``GETSOURCE``, and ``STOP``.
"""
if unix:
sock = socket.socket(socket.AF_UNIX)
sock.connect(sockets_prefix + address)
else:
sock = socket.socket(socket.AF_INET)
# NEW: keepalive + nodelay for low-latency, long-lived tiny messages
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except (OSError, AttributeError):
pass
sock.connect((address, port))
sock.settimeout(timeout)
initialized = False
have_result = False
pending_amp = None
additional_data = {}
dt_au = 0.0
molid = None
while True:
try:
msg = _recv_msg(sock)
except Exception:
# Treat EOF/timeouts during normal shutdown as clean exit
break
if msg == STATUS:
# Server is polling; we must reply with our state.
if not initialized:
_send_msg(sock, NEEDINIT)
elif have_result:
_send_msg(sock, HAVEDATA)
else:
_send_msg(sock, READY)
elif msg == INIT:
# Server sends INIT after we replied NEEDINIT
molid = _recv_int(sock)
init_json = json.loads(_recv_bytes(sock).decode("utf-8") or "{}")
dt_au = float(init_json.get("dt_au", 0.0))
print("[initialization] Time step in atomic units:", dt_au)
print("[initialization] Assigned a molecular ID:", molid)
driver.initialize(dt_au, molid)
initialized = True
print("[initialization] Finished initialization for molecular ID:", molid)
elif msg == FIELDDATA or msg == b"POSDATA":
# One step of data from server: treat "positions" as the E-field vector in a.u.
# This is to mirror i-pi's existing socket interface.
cell, icell, xyz = _recv_posdata(sock)
# effective [Ex, Ey, Ez] (a.u.) for this molecule
E = xyz[0]
# Stage the step (no commit)
driver.stage_step(E)
have_result = True
elif msg == GETSOURCE or msg == b"GETFORCE":
# Server asks us to return the result for this step
if not driver.have_result():
# it means the driver code was terminated during driver.propagate() and driver.calc_amp_vector()
# one way is to be defensive: return zero if we somehow got here without a computed result
pending_amp = np.zeros(3, float)
else:
pending_amp = driver.commit_step()
additional_data = driver.append_additional_data()
_send_force_ready(
sock,
energy_ha=0.0,
forces_Nx3_ha_per_bohr=pending_amp.reshape(1, 3),
virial_3x3_ha=np.zeros((3, 3)),
more=json.dumps(
additional_data,
ensure_ascii=False,
separators=(",", ":"),
sort_keys=True,
).encode("utf-8"),
)
have_result = False
pending_amp = None
elif msg == STOP:
# Acknowledge and leave gracefully
try:
_send_msg(sock, BYE)
finally:
print("Received STOP, exiting")
break
else:
raise RuntimeError(f"Unexpected header: {msg!r}")
[docs]
def mxl_driver_main():
"""
Parse CLI arguments and start the MaxwellLink socket driver.
Notes
-----
Constructs the selected model via ``__drivers__`` using the ``--model`` and
``--param`` options, then calls ``run_driver(...)``.
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-u",
"--unix",
action="store_true",
default=False,
help="Use a UNIX domain socket.",
)
parser.add_argument(
"-a",
"--address",
type=str,
default="localhost",
help="Host name (for INET sockets) or name of the UNIX domain socket to connect to.",
)
parser.add_argument(
"-S",
"--sockets_prefix",
type=str,
default="/tmp/socketmxl_",
help="Prefix used for the unix domain sockets. Ignored when using TCP/IP sockets.",
)
parser.add_argument(
"-p",
"--port",
type=int,
default=31415,
help="TCP/IP port number. Ignored when using UNIX domain sockets.",
)
parser.add_argument(
"-m",
"--model",
type=str,
default="dummy",
choices=list(__drivers__.keys()),
help="""Type of molecular / material model for computing dipole moments under EM field.
""",
)
parser.add_argument(
"-o",
"--param",
type=str,
default="",
help="""Parameters required to run the driver. Comma-separated list of values
""",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Verbose output.",
)
args = parser.parse_args()
driver_args, driver_kwargs = _read_args_kwargs(args.param)
if args.model in __drivers__:
try:
d_f = __drivers__[args.model](
*driver_args, verbose=args.verbose, **driver_kwargs
)
except ImportError:
# specific errors have already been triggered
raise
except Exception as err:
print(f"Error setting up molecular dynamics model {args.model}")
print(__drivers__[args.model].__doc__)
print("Error trace: ")
raise err
elif args.model == "dummy":
d_f = DummyModel(verbose=args.verbose)
else:
raise ValueError("Unsupported driver model ", args.model)
run_driver(
unix=args.unix,
address=args.address,
port=args.port,
driver=d_f,
sockets_prefix=args.sockets_prefix,
)
def _clean_env_for_subprocess():
"""
Return a copy of the environment with MPI-related variables removed.
Returns
-------
dict
Sanitized environment dictionary suitable for launching child processes.
"""
env = os.environ.copy()
# Nuke anything that makes a child think it's an MPI rank
prefixes = (
"PMI_",
"PMIX_",
"OMPI_",
"MPI_",
"MPICH_",
"I_MPI_",
"HYDRA_",
"SLURM_",
"FI_",
"UCX_",
"PSM2_",
"PMI",
)
for k in list(env.keys()):
for p in prefixes:
if k.startswith(p):
env.pop(k, None)
break
# Some MPIs set these exact names without a prefix
for k in ("PMI_FD", "PMI_PORT", "PMI_ID", "PMI_RANK", "PMI_SIZE"):
env.pop(k, None)
return env
[docs]
def launch_driver(
command='--model tls --port 31415 --param "omega=0.242, mu12=187, orientation=2, pe_initial=1e-4" --verbose',
sleep_time=0.5,
):
"""
Launch the driver as a background subprocess for local testing.
Parameters
----------
command : str, default: '--model tls --port 31415 --param "omega=0.242, mu12=187, orientation=2, pe_initial=1e-4" --verbose'
Command-line arguments passed to ``mxl_driver.py``.
sleep_time : float, default: 0.5
Time to sleep (seconds) after launch to allow initialization.
Returns
-------
subprocess.Popen or None
The process handle on the master rank, otherwise ``None``.
"""
if _am_master():
print(f"Launching driver with command: mxl_driver.py {command}")
# launch the external driver (client)
driver_argv = shlex.split(shutil.which("mxl_driver.py") + " " + command)
# Use a fresh, non-blocking subprocess; inherit env/stdio for easy debugging
proc = subprocess.Popen(driver_argv, env=_clean_env_for_subprocess())
time.sleep(sleep_time)
return proc
else:
return None
[docs]
def terminate_driver(proc, timeout=2.0):
"""
Terminate a driver process launched by ``launch_driver``.
Parameters
----------
proc : subprocess.Popen or None
Process handle to terminate.
timeout : float, default: 2.0
Seconds to wait for graceful shutdown before escalating.
"""
if proc is not None and _am_master():
# Give it a moment to shut down naturally after the sim closes the socket
try:
proc.wait(timeout=timeout)
except subprocess.TimeoutExpired:
proc.terminate()
print("Driver did not exit cleanly, sent terminate signal")
try:
proc.wait(timeout=timeout)
except subprocess.TimeoutExpired:
proc.kill()
print("Driver did not terminate, sent kill signal")
if __name__ == "__main__":
mxl_driver_main()