"""
Socket layer for MaxwellLink drivers and servers.
This module implements a lightweight socket protocol inspired by i-PI
(https://ipi-code.org/) and provides:
- **SocketHub**: a multi-client server/poller for coordinating many driver
connections with an FDTD engine.
- **Protocol constants**: ``STATUS``, ``READY``, ``HAVEDATA``, ``NEEDINIT``,
``INIT``, ...
- **EM aliases**: ``FIELDDATA``, ``GETSOURCE``, ``SOURCEREADY`` (1:1 mapping to
``POSDATA``/``GETFORCE``/``FORCEREADY``).
- **Low-level helpers**: ``_send_msg``, ``_recv_msg``, ``_send_array``/``_recv_array``,
etc.
- **Exceptions**: ``_SocketClosed``.
"""
from __future__ import annotations
import socket, struct, json, time, threading, os
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple
import numpy as np
_INT32 = struct.Struct("<i")
_FLOAT64 = struct.Struct("<d")
# Fixed header width (ASCII, space-padded)
HEADER_LEN = 12
# Canonical i-PI message codes
STATUS = b"STATUS"
READY = b"READY"
HAVEDATA = b"HAVEDATA"
NEEDINIT = b"NEEDINIT"
INIT = b"INIT"
POSDATA = b"POSDATA"
GETFORCE = b"GETFORCE"
FORCEREADY = b"FORCEREADY"
STOP = b"STOP"
BYE = b"BYE"
# EM aliases for readability (same wire format)
FIELDDATA = POSDATA
GETSOURCE = GETFORCE
SOURCEREADY = FORCEREADY
# numpy dtypes on the wire (i-PI/ASE use float64 for reals, int32 for counts)
DT_FLOAT = np.float64
DT_INT = np.int32
class _SocketClosed(OSError):
"""
Exception raised when the peer closes the socket unexpectedly.
"""
pass
def _pad12(msg: bytes) -> bytes:
"""
Pad a message to the fixed 12-byte ASCII header width.
Parameters
----------
msg : bytes
Message tag to send.
Returns
-------
bytes
Space-padded header of exactly 12 bytes.
Raises
------
ValueError
If ``msg`` exceeds the 12-byte header length.
"""
if len(msg) > HEADER_LEN:
raise ValueError("Header too long")
return msg.ljust(HEADER_LEN, b" ")
def _send_msg(sock: socket.socket, msg: bytes) -> None:
"""
Send a 12-byte ASCII header (space-padded).
Parameters
----------
sock : socket.socket
Connected socket.
msg : bytes
Message tag to send (e.g., ``b"STATUS"``).
"""
sock.sendall(_pad12(msg))
def _recvall(sock: socket.socket, n: int) -> bytes:
"""
Read exactly ``n`` bytes from a socket.
Parameters
----------
sock : socket.socket
Connected socket.
n : int
Number of bytes to read.
Returns
-------
bytes
The data read.
Raises
------
_SocketClosed
If the peer closes the connection before all bytes are received.
"""
"""Read exactly n bytes or raise _SocketClosed."""
buf = bytearray()
while len(buf) < n:
chunk = sock.recv(n - len(buf))
if not chunk:
raise _SocketClosed("Peer closed")
buf.extend(chunk)
return bytes(buf)
def _recv_msg(sock: socket.socket) -> bytes:
"""
Receive a 12-byte ASCII header.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
bytes
The received header without trailing spaces.
"""
"""Receive 12-byte ASCII header."""
hdr = _recvall(sock, HEADER_LEN)
return hdr.rstrip()
def _send_array(sock: socket.socket, arr, dtype) -> None:
"""
Send a NumPy array over a socket using a contiguous C-order memory view.
Parameters
----------
sock : socket.socket
Connected socket.
arr : array-like
Array data to send.
dtype : numpy.dtype
Data type to cast and send as (e.g., ``np.float64``).
"""
a = np.asarray(arr, dtype=dtype, order="C")
sock.sendall(memoryview(a).cast("B"))
def _recv_array(sock: socket.socket, shape, dtype):
"""
Receive a NumPy array of a given shape and dtype from a socket.
Parameters
----------
sock : socket.socket
Connected socket.
shape : tuple of int
Expected array shape.
dtype : numpy.dtype
Expected dtype (e.g., ``np.float64``).
Returns
-------
numpy.ndarray
The received array with the specified shape and dtype.
Raises
------
_SocketClosed
If the peer closes the connection during the transfer.
"""
out = np.empty(shape, dtype=dtype, order="C")
mv = memoryview(out).cast("B")
need = mv.nbytes
got = 0
while got < need:
r = sock.recv_into(mv[got:], need - got)
if r == 0:
raise _SocketClosed("Peer closed")
got += r
return out
def _send_int(sock: socket.socket, x: int) -> None:
"""
Send a 32-bit little-endian integer.
Parameters
----------
sock : socket.socket
Connected socket.
x : int
Integer value to send.
"""
sock.sendall(_INT32.pack(int(x)))
def _recv_int(sock: socket.socket) -> int:
"""
Receive a 32-bit little-endian integer.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
int
The received integer.
Raises
------
_SocketClosed
If the peer closes the connection during the transfer.
"""
buf = bytearray(_INT32.size)
mv = memoryview(buf)
got = 0
while got < _INT32.size:
r = sock.recv_into(mv[got:], _INT32.size - got)
if r == 0:
raise _SocketClosed("Peer closed")
got += r
return _INT32.unpack(buf)[0]
def _send_bytes(sock: socket.socket, b: bytes) -> None:
"""
Send a length-prefixed byte string.
Parameters
----------
sock : socket.socket
Connected socket.
b : bytes
Byte string to send. The length is sent first as a 32-bit integer.
"""
_send_int(sock, len(b))
if len(b):
sock.sendall(b)
def _recv_bytes(sock: socket.socket) -> bytes:
"""
Receive a length-prefixed byte string.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
bytes
The received byte string (may be empty).
"""
n = _recv_int(sock)
return _recvall(sock, n) if n else b""
# -------- compound payloads (i-PI compatible) --------
def _send_posdata(
sock: socket.socket, cell_3x3_bohr, invcell_3x3_per_bohr, positions_Nx3_bohr
):
"""
Send a POSDATA/FIELDDATA block: cell, inverse cell, natoms, positions.
Parameters
----------
sock : socket.socket
Connected socket.
cell_3x3_bohr : array-like, shape (3, 3)
Lattice vectors (Bohr).
invcell_3x3_per_bohr : array-like, shape (3, 3)
Inverse lattice (1/Bohr).
positions_Nx3_bohr : array-like, shape (N, 3)
Atomic positions (Bohr).
Notes
-----
For EM use, this is also used to carry field vectors via the positions payload.
"""
assert np.asarray(cell_3x3_bohr).shape == (3, 3)
assert np.asarray(invcell_3x3_per_bohr).shape == (3, 3)
pos = np.asarray(positions_Nx3_bohr, dtype=DT_FLOAT)
assert pos.ndim == 2 and pos.shape[1] == 3
_send_msg(sock, POSDATA)
_send_array(sock, np.asarray(cell_3x3_bohr, dtype=DT_FLOAT).T, DT_FLOAT)
_send_array(sock, np.asarray(invcell_3x3_per_bohr, dtype=DT_FLOAT).T, DT_FLOAT)
_send_int(sock, pos.shape[0])
_send_array(sock, pos, DT_FLOAT)
def _recv_posdata(sock: socket.socket):
"""
Read a POSDATA/FIELDDATA block.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
tuple
``(cell, icell, xyz)`` where:
- ``cell`` : ``(3, 3)`` ndarray (row-major), simulation cell.
- ``icell`` : ``(3, 3)`` ndarray (row-major), inverse cell.
- ``xyz`` : ``(nat, 3)`` ndarray of positions (or effective field payload).
"""
cell = _recv_array(sock, (3, 3), DT_FLOAT).T.copy()
icell = _recv_array(sock, (3, 3), DT_FLOAT).T.copy()
nat = _recv_int(sock)
xyz = _recv_array(sock, (nat, 3), DT_FLOAT)
return cell, icell, xyz
def _send_force_ready(
sock: socket.socket,
energy_ha: float,
forces_Nx3_ha_per_bohr,
virial_3x3_ha,
more: bytes = b"",
):
"""
Send a FORCEREADY/SOURCEREADY message with energy, forces, virial, and extras.
Parameters
----------
sock : socket.socket
Connected socket.
energy_ha : float
Total energy (Hartree).
forces_Nx3_ha_per_bohr : array-like, shape (N, 3)
Forces (Hartree/Bohr).
virial_3x3_ha : array-like, shape (3, 3)
Virial tensor (Hartree).
more : bytes, optional
Extra payload (length-prefixed), e.g., JSON metadata.
"""
_send_msg(sock, FORCEREADY)
_send_array(sock, np.array([energy_ha], dtype=DT_FLOAT), DT_FLOAT)
forces = np.asarray(forces_Nx3_ha_per_bohr, dtype=DT_FLOAT)
assert forces.ndim == 2 and forces.shape[1] == 3
_send_int(sock, forces.shape[0])
_send_array(sock, forces, DT_FLOAT)
_send_array(sock, np.asarray(virial_3x3_ha, dtype=DT_FLOAT).T, DT_FLOAT)
_send_bytes(sock, more)
def _recv_getforce(sock: socket.socket):
"""
Receive a FORCEREADY/SOURCEREADY payload after a GETFORCE/GETSOURCE request.
Parameters
----------
sock : socket.socket
Connected socket.
Returns
-------
tuple
``(energy, forces, virial, extra)`` where ``energy`` is a float, ``forces``
is an ``(N, 3)`` ndarray, ``virial`` is a ``(3, 3)`` ndarray, and
``extra`` is raw bytes.
"""
e = float(_recv_array(sock, (1,), DT_FLOAT)[0])
nat = _recv_int(sock)
frcs = _recv_array(sock, (nat, 3), DT_FLOAT)
vir = _recv_array(sock, (3, 3), DT_FLOAT).T.copy()
extra = _recv_bytes(sock)
return e, frcs, vir, extra
# -------- convenience wrappers for EM (i-PI compatible) --------
def _pack_em_fieldata(
sock: socket.socket, t_au: float, dt_au: float, efield_au_vec3, meta: dict
):
"""
Send EM field data encoded as POSDATA with ``natoms=1`` and
``positions = [E_x, E_y, E_z]``.
Parameters
----------
sock : socket.socket
Connected socket.
t_au : float
Current time (a.u.). (Informational; not transmitted in POSDATA.)
dt_au : float
Time step (a.u.). (Informational; not transmitted in POSDATA.)
efield_au_vec3 : array-like, shape (3,)
Electric field vector ``[E_x, E_y, E_z]`` in a.u.
meta : dict
Optional metadata carried alongside in higher-level protocols.
"""
I = np.eye(3, dtype=DT_FLOAT)
exyz = np.asarray(efield_au_vec3, dtype=DT_FLOAT).reshape(1, 3)
_send_posdata(sock, I, I, exyz)
# meta/time tags can be sent back in SOURCEREADY's extra blob if needed.
def _pack_init(sock: socket.socket, init_dict: dict):
"""
Send an INIT handshake containing a JSON payload.
Parameters
----------
sock : socket.socket
Connected socket.
init_dict : dict
Initialization dictionary (e.g., includes ``"molecule_id"``).
"""
_send_msg(sock, INIT)
molid = int(init_dict.get("molecule_id", 0))
_send_int(sock, molid)
init_bytes = json.dumps(init_dict).encode("utf-8")
_send_bytes(sock, init_bytes)
@dataclass
class _ClientState:
"""
Dataclass storing per-client state for the socket hub.
Attributes
----------
sock : socket.socket
Connected client socket.
address : str
Peer address string.
molecule_id : int
Bound molecule identifier (``-1`` if unbound).
last_amp : numpy.ndarray or None
Last source amplitude vector ``(3,)``.
pending_send : bool
Whether a field has been dispatched but not yet committed.
initialized : bool
Whether INIT has been completed.
alive : bool
Connection liveness flag.
extras : dict
Arbitrary metadata associated with the client.
"""
sock: socket.socket
address: str
molecule_id: int
last_amp: Optional[np.ndarray] = None # last source amplitude (3,)
pending_send: bool = False
initialized: bool = False
alive: bool = True
extras: dict = field(default_factory=dict)
[docs]
def get_available_host_port(localhost=True, save_to_file=None) -> Tuple[str, int]:
"""
Ask the OS for an available localhost TCP port.
Parameters
----------
localhost : bool, default: True
If True, bind to the localhost interface ("127.0.0.1"). If False, bind to all interfaces ("0.0.0.0").
save_to_file : str or None, default: None
If provided, save the selected host and port to the given file with filename provided by `save_to_file`.
The first line contains the host, and the second line contains the port.
Returns
-------
tuple
``(host, port)`` pair, e.g., ``("127.0.0.1", 34567)``.
"""
bind_addr = "127.0.0.1" if localhost else "0.0.0.0"
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((bind_addr, 0))
port = s.getsockname()[1]
ip = "127.0.0.1"
if not localhost:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp:
tmp.connect(("8.8.8.8", 80))
ip = tmp.getsockname()[0]
if am_master():
# save host and port number to a file so mxl_driver can read it
if save_to_file is not None:
with open(save_to_file, "w") as f:
f.write(f"{ip}\n{port}\n")
return ip, port
# helper function to determine whether this processor is the MPI master using mpi4py
[docs]
def am_master():
"""
Return True if this process is the MPI master rank (rank 0), otherwise False.
Notes
-----
Attempts to import ``mpi4py`` and query ``COMM_WORLD``. If unavailable,
returns ``True`` by treating the single process as rank 0.
"""
try:
from mpi4py import MPI as _MPI
_COMM = _MPI.COMM_WORLD
_RANK = _COMM.Get_rank()
except Exception:
_COMM = None
_RANK = 0
return _RANK == 0
# helper function to broadcast a value from master to all MPI ranks
[docs]
def mpi_bcast_from_master(value):
"""
Broadcast a Python value from the master rank to all ranks via MPI.
Parameters
----------
value : any
The value to broadcast.
Returns
-------
any
The broadcast value (unchanged when MPI is unavailable).
"""
try:
from mpi4py import MPI as _MPI
_COMM = _MPI.COMM_WORLD
except Exception:
_COMM = None
if _COMM is not None:
value = _COMM.bcast(value, root=0)
return value
[docs]
class SocketHub:
"""
Socket server coordinating multiple driver connections with an FDTD engine.
This server:
- Accepts and tracks many driver connections.
- Handles initialization handshakes, field dispatch, and result collection.
- Provides a barrier-style step to send fields and receive source amplitudes
from all registered molecules.
"""
[docs]
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = 31415,
unixsocket: Optional[str] = None,
timeout: float = 60.0,
latency: float = 0.01,
):
"""
Initialize the socket hub.
Parameters
----------
host : str or None, default: None
Host address for AF_INET sockets. Ignored when using a UNIX socket.
port : int or None, default: 31415
TCP port for AF_INET sockets. Ignored for UNIX sockets.
unixsocket : str or None, default: None
Path (or name under ``/tmp/socketmxl_*``) for a UNIX domain socket. When
provided, ``host`` and ``port`` are ignored.
timeout : float, default: 60.0
Socket timeout (seconds) for client operations.
latency : float, default: 0.01
Polling sleep (seconds) between hub sweeps; can be very small for local runs.
"""
self.unixsocket_path = None
if am_master():
if unixsocket:
self.serversock = socket.socket(socket.AF_UNIX)
# mirror i-PI's /tmp/ipi_* default when given a name
if not unixsocket.startswith("/"):
unixsocket = f"/tmp/socketmxl_{unixsocket}"
self.unixsocket_path = unixsocket
if os.path.exists(self.unixsocket_path):
probe = socket.socket(socket.AF_UNIX)
try:
probe.settimeout(0.25)
probe.connect(self.unixsocket_path)
except FileNotFoundError:
pass
except ConnectionRefusedError:
try:
os.unlink(self.unixsocket_path)
except FileNotFoundError:
pass
else:
probe.close()
raise RuntimeError(
f"Socket path {self.unixsocket_path} already in use"
)
finally:
try:
probe.close()
except Exception:
pass
self.serversock.bind(unixsocket)
self._where = unixsocket
else:
self.serversock = socket.socket(socket.AF_INET)
self.serversock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
host = host or ""
port = port or 31415
self.serversock.bind((host, port))
self._where = f"{host}:{port}"
self.serversock.listen(4096)
self.serversock.settimeout(0.25)
self.timeout = float(timeout)
self.latency = float(latency)
# key: molecule_id or temp id
self.clients: Dict[int, _ClientState] = {}
# peer -> molecule_id
self.addrmap: Dict[str, int] = {}
self._stop = False
self._lock = threading.RLock()
self._accept_th = threading.Thread(target=self._accept_loop, daemon=True)
self._accept_th.start()
# assign a molecular id accumulator
self._molecule_id_counter = 0
# molecule_id -> _ClientState (locked client)
self.bound: Dict[int, _ClientState] = {}
# molecule ids we expect to serve
self.expected: set[int] = set()
# global pause when any driver is down
self.paused = False
# holds a frozen barrier until it successfully commits
self._inflight = None
def _accept_loop(self):
"""
Accept-loop thread: accept new connections and register temporary clients.
"""
while not self._stop:
try:
csock, addr = self.serversock.accept()
except socket.timeout:
continue
except OSError:
break
# NEW: trim latency and keep long-lived connections healthy
try:
csock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# Only for AF_INET; will raise on AF_UNIX
csock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except (OSError, AttributeError):
pass # AF_UNIX or platform without TCP_NODELAY
peer = addr if isinstance(addr, str) else f"{addr[0]}:{addr[1]}"
csock.settimeout(self.timeout)
st = _ClientState(sock=csock, address=peer, molecule_id=-1)
with self._lock:
# temp key: use id(csock) until INIT binds molecule_id
self.clients[id(csock)] = st
def _maybe_init_client(self, st: _ClientState, init_payload: dict):
"""
Send INIT to a client with the given payload and mark it initialized.
Parameters
----------
st : _ClientState
Client state to initialize.
init_payload : dict
Initialization payload (e.g., contains ``"molecule_id"``).
"""
_pack_init(st.sock, init_payload)
st.initialized = True
def _dispatch_field(self, st: _ClientState, efield_au: np.ndarray, meta: dict):
"""
Dispatch an EM field vector to a client via FIELDDATA/POSDATA.
Parameters
----------
st : _ClientState
Target client state.
efield_au : numpy.ndarray
Electric field vector ``(3,)`` in a.u.
meta : dict
Optional metadata to attach to this send.
Raises
------
_SocketClosed or OSError
If the client disconnects during send.
"""
try:
_send_msg(st.sock, FIELDDATA)
I = np.eye(3, dtype=DT_FLOAT)
_send_array(st.sock, I.T, DT_FLOAT)
_send_array(st.sock, I.T, DT_FLOAT)
_send_int(st.sock, 1)
vec = np.asarray(efield_au, dtype=DT_FLOAT).reshape(1, 3)
_send_array(st.sock, vec, DT_FLOAT)
st.pending_send = True
st.extras.update(meta or {})
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED (send): mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
raise
def _query_result(self, st: _ClientState) -> Tuple[np.ndarray, bytes]:
"""
Request a client's source amplitude and read the READY payload.
Parameters
----------
st : _ClientState
Client state to query.
Returns
-------
tuple
``(amp_vec3, extra_bytes)`` where ``amp_vec3`` is a ``(3,)`` ndarray and
``extra_bytes`` carries auxiliary data.
Raises
------
_SocketClosed or OSError
If the client disconnects during the exchange.
"""
try:
_send_msg(st.sock, GETSOURCE)
msg = _recv_msg(st.sock)
if msg != SOURCEREADY:
raise _SocketClosed(f"Expected {SOURCEREADY!r}, got {msg!r}")
e, forces, vir, extra = _recv_getforce(st.sock)
amp = np.array(forces[0], dtype=float) # (3,)
st.last_amp = amp
st.pending_send = False
return amp, extra
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED (recv): mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
raise
def _bind_client_locked(
self, st: _ClientState, molid: int, init_payload: dict, st_key
):
"""
Bind a client to a molecule ID if available and perform INIT.
Parameters
----------
st : _ClientState
Client to bind.
molid : int
Molecule ID to bind to.
init_payload : dict
INIT payload to send.
st_key : int
Temporary key under which the client is stored.
Returns
-------
bool
``True`` if binding succeeded, otherwise ``False``.
"""
if self.bound.get(molid) is None:
self._maybe_init_client(st, init_payload)
st.molecule_id = molid
self.bound[molid] = st
self.addrmap[st.address] = molid
self.clients[molid] = st
if st_key != molid:
try:
del self.clients[st_key]
except KeyError:
pass
address = st.address
self._log(f"CONNECTED: mol {molid} <- {address}")
# NEW: this molid is part of a frozen barrier -> force re-dispatch
self._reset_inflight_for(molid)
st.pending_send = False # defensive: this is a fresh socket
return True
return False
def _log(self, *a):
"""
Log a message with the ``[SocketHub]`` prefix.
"""
print("[SocketHub]", *a)
def _pause(self):
"""
Pause the hub (used when a driver disconnects mid-barrier).
"""
self.paused = True
def _resume(self):
"""
Resume the hub after a pause.
"""
self.paused = False
def _reset_inflight_for(self, molid: int):
"""
Force re-dispatch for ``molid`` in a frozen barrier after reconnect.
Parameters
----------
molid : int
Molecule ID to reset in the current barrier state.
"""
if self._inflight and (molid in self._inflight["wants"]):
self._inflight["sent"][molid] = False
self._inflight["ready"][molid] = False
def _find_free_molecule_id(self) -> int:
"""
Find and return an available molecule ID not already registered.
Returns
-------
int
A unique molecule ID.
"""
while True:
molecule_id = self._molecule_id_counter
self._molecule_id_counter += 1
if molecule_id not in self.expected:
return molecule_id
# -------------- public API --------------
[docs]
def register_molecule(self, molecule_id: int) -> None:
"""
Reserve a slot for a given molecule ID (client may connect later).
Parameters
----------
molecule_id : int
Molecule ID to register.
Raises
------
ValueError
If the molecule ID is already registered.
"""
with self._lock:
# If already registered, raising a ValueError
if molecule_id in self.expected:
raise ValueError(f"Molecule ID {molecule_id} already registered!")
# No explicit state needed yet; client binds on INIT.
self.expected.add(int(molecule_id))
self.bound.setdefault(int(molecule_id), None)
[docs]
def register_molecule_return_id(self) -> int:
"""
Reserve a slot for a molecule and return an auto-assigned ID.
Returns
-------
int
The assigned unique molecule ID.
"""
with self._lock:
# Find an available molecule_id
molecule_id = self._find_free_molecule_id()
self.register_molecule(molecule_id)
return molecule_id
[docs]
def step_barrier(
self, requests: Dict[int, dict], timeout: Optional[float] = None
) -> Dict[int, np.ndarray]:
"""
Barrier step: dispatch fields and collect source amplitudes from all clients.
Coordinates sending fields, waiting for results, and jointly committing the
results once every requested molecule is ready. A frozen barrier is reused if
a disconnect occurs mid-step.
Parameters
----------
requests : dict[int, dict]
Mapping from molecule ID to request dict with keys:
- ``"efield_au"`` : array-like ``(3,)`` field vector in a.u.
- ``"meta"`` : dict, optional metadata per send.
- ``"init"`` : dict, optional INIT payload for first bind.
timeout : float, optional
Maximum time (seconds) to wait for the barrier to complete. Defaults to the
hub's ``timeout`` setting.
Returns
-------
dict[int, dict]
Mapping ``molid -> {"amp": ndarray(3,), "extra": bytes}``. Returns ``{}``
when paused, on abort, or if the barrier is incomplete.
"""
if self.paused:
return {}
deadline = time.time() + (timeout or self.timeout)
results: Dict[int, dict] = {}
# If a barrier is already in flight, ignore new 'requests' and reuse the frozen one.
if self._inflight is None:
wants = set(requests.keys())
self._inflight = {
"wants": wants,
"efields": {
mid: np.asarray(requests[mid]["efield_au"], dtype=DT_FLOAT).copy()
for mid in wants
},
"meta": {mid: requests[mid].get("meta", {}) for mid in wants},
"sent": {mid: False for mid in wants},
"ready": {mid: False for mid in wants},
}
else:
# Reuse the frozen barrier even if the caller passed different fields
wants = set(self._inflight["wants"])
ready = self._inflight["ready"]
# --- hard gate: do not dispatch fields until everyone is bound ---
ids = set(int(k) for k in requests.keys())
with self._lock:
if not self.all_bound(ids, require_init=True):
# Try to progress INIT quickly, but DO NOT send FIELDDATA yet
for st_key, st in list(self.clients.items()):
if not st or not st.alive:
continue
try:
_send_msg(st.sock, STATUS)
reply = _recv_msg(st.sock)
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
self._pause()
# NEW: make the frozen barrier re-send the old field to this molid
self._reset_inflight_for(st.molecule_id)
continue
if reply == NEEDINIT:
for mid in ids:
if self.bound.get(mid) is None:
init_payload = requests.get(mid, {}).get(
"init", {"molecule_id": mid}
)
self._bind_client_locked(
st, int(mid), init_payload, st_key
)
break
return {} # nothing dispatched; drivers remain idle
# --- normal step_barrier continues below ---
aborted = False
with self._lock:
# 1. poll and init/dispatch
for st_key, st in list(self.clients.items()):
try:
_send_msg(st.sock, STATUS)
reply = _recv_msg(st.sock)
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
self._pause()
# NEW: make the frozen barrier re-send the old field to this molid
self._reset_inflight_for(st.molecule_id)
aborted = True
continue
if reply == NEEDINIT:
# Bind this client to the first UNBOUND molecule id present in requests
# Preserve the requests' order (Python dicts keep insertion order).
chosen = None
for mid in requests.keys():
if self.bound.get(mid) is None:
chosen = mid
break
# Fallback: allow a rebind if this client was previously bound and crashed/reconnected
if (
chosen is None
and st.molecule_id >= 0
and self.bound.get(st.molecule_id) is None
):
chosen = st.molecule_id
if chosen is None:
# Nothing to serve right now; keep the client idle
continue
init_payload = requests.get(chosen, {}).get(
"init", {"molecule_id": chosen}
)
self._bind_client_locked(st, int(chosen), init_payload, st_key)
elif reply == READY:
molid = st.molecule_id
if (
molid in self._inflight["wants"]
and not st.pending_send
and not self._inflight["sent"][molid]
):
evec = self._inflight["efields"][molid]
meta = self._inflight["meta"][molid]
try:
self._dispatch_field(st, evec, meta)
self._inflight["sent"][molid] = True
except Exception:
aborted = True
break
elif reply == HAVEDATA:
if st.molecule_id in ready:
ready[st.molecule_id] = True
continue
# 2. second pass: wait for all pending molecules to finish (barrier)
wants = set(self._inflight["wants"])
if aborted:
# abort this barrier; caller will enter the pause path
return {}
while time.time() < deadline and (wants - set(results.keys())):
time.sleep(self.latency)
with self._lock:
# Always sweep over ALL clients, not just molecule-ids in 'wants'
for st_key, st in list(self.clients.items()):
if not st or not st.alive:
continue
try:
_send_msg(st.sock, STATUS)
reply = _recv_msg(st.sock)
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
self._pause()
# NEW: make the frozen barrier re-send the old field to this molid
self._reset_inflight_for(st.molecule_id)
continue
if reply == NEEDINIT:
# Bind this client to the first UNBOUND molecule id present in requests
# Preserve the requests' order (Python dicts keep insertion order).
chosen = None
for mid in requests.keys():
if self.bound.get(mid) is None:
chosen = mid
break
# Fallback: allow a rebind if this client was previously bound and crashed/reconnected
if (
chosen is None
and st.molecule_id >= 0
and self.bound.get(st.molecule_id) is None
):
chosen = st.molecule_id
if chosen is None:
# Nothing to serve right now; keep the client idle
continue
init_payload = requests.get(chosen, {}).get(
"init", {"molecule_id": chosen}
)
self._bind_client_locked(st, int(chosen), init_payload, st_key)
continue # next client
if reply == READY:
molid = st.molecule_id
if (
molid in self._inflight["wants"]
and not st.pending_send
and not self._inflight["sent"][molid]
):
evec = self._inflight["efields"][molid]
meta = self._inflight["meta"][molid]
self._dispatch_field(st, evec, meta)
continue
if reply == HAVEDATA:
if st.molecule_id in ready:
ready[st.molecule_id] = True
continue
# Some clients may erroneously send STATUS; ignore gracefully
if reply == STATUS:
continue
# exit condition: filled everything we wanted
if all(ready.get(mid, False) for mid in wants):
break
# Abort if not everyone is ready (e.g., disconnect); keep the frozen barrier.
if not all(
self._inflight["ready"].get(mid, False) for mid in self._inflight["wants"]
):
return {}
# Phase C: commit all together (send GETSOURCE to everyone now)
for mid in self._inflight["wants"]:
st = self.clients.get(mid)
if not st or not st.alive:
return {} # keep barrier for retry
try:
amp, extra = self._query_result(st)
results[mid] = {"amp": amp, "extra": extra}
except (socket.timeout, _SocketClosed, OSError):
return {} # keep barrier for retry
# SUCCESS — clear the frozen barrier
self._inflight = None
# Fallback: any missing results -> use last known amplitude (if any),
# We turn off this fallback for now to avoid silent errors.
if False:
with self._lock:
for mid in wants - set(results.keys()):
st = self.clients.get(mid)
if st and st.last_amp is not None:
results[mid] = {"amp": st.last_amp.copy(), "extra": b""}
return results
[docs]
def all_bound(self, molecule_ids, require_init=True):
"""
Check if all given molecule IDs are bound (and optionally initialized).
Parameters
----------
molecule_ids : iterable of int
Molecule IDs to check.
require_init : bool, default: True
Also require that clients completed INIT.
Returns
-------
bool
``True`` if all are bound (and initialized if requested), else ``False``.
"""
with self._lock:
for mid in molecule_ids:
st = self.bound.get(int(mid))
if st is None or not st.alive:
return False
if require_init and not st.initialized:
return False
return True
[docs]
def wait_until_bound(self, init_payloads: dict, require_init=True, timeout=None):
"""
Block until all requested molecule IDs are bound (and optionally initialized).
Parameters
----------
init_payloads : dict[int, dict]
Mapping from molecule ID to INIT payload to use on bind.
require_init : bool, default: True
Also require that clients completed INIT.
timeout : float or None, optional
Maximum time to wait (seconds). Uses hub default if ``None``.
Returns
-------
bool
``True`` if all requested IDs became bound within the time limit, else ``False``.
"""
wanted = {int(k) for k in init_payloads.keys()}
deadline = time.time() + (timeout if timeout is not None else self.timeout)
while True:
if self.all_bound(wanted, require_init=require_init):
self._resume()
return True
# Progress the INIT handshakes without dispatching any field steps
with self._lock:
# Only touch clients for molecule-ids that are still unbound
pending_ids = {mid for mid in wanted if self.bound.get(mid) is None}
if not pending_ids:
# (shouldn't happen because of the all_bound check, but be safe)
time.sleep(self.latency)
continue
for st_key, st in list(self.clients.items()):
if not st or not st.alive:
continue
# Only ping clients that are currently unbound (either brand-new or reconnecting)
if st.molecule_id >= 0 and st.molecule_id not in pending_ids:
continue
try:
_send_msg(st.sock, STATUS)
reply = _recv_msg(st.sock)
except (socket.timeout, _SocketClosed, OSError):
st.alive = False
# free a binding if this was a re-connect case
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
self._pause()
# NEW: make the frozen barrier re-send the old field to this molid
self._reset_inflight_for(st.molecule_id)
continue
if reply == NEEDINIT:
# choose an unclaimed id from 'wanted'
chosen = None
for mid in pending_ids:
if self.bound.get(mid) is None:
chosen = mid
break
if chosen is not None:
self._bind_client_locked(
st, int(chosen), init_payloads[int(chosen)], st_key
)
if timeout is not None and time.time() > deadline:
return False
time.sleep(self.latency)
[docs]
def graceful_shutdown(self, reason: Optional[str] = None, wait: float = 2.0):
"""
Politely ask all connected drivers to exit and wait briefly for ``BYE``.
Parameters
----------
reason : str or None, optional
Optional reason to log for shutdown.
wait : float, default: 2.0
Seconds to wait for clean replies.
"""
with self._lock:
for st in list(self.clients.values()):
if not st or not st.alive:
continue
try:
_send_msg(st.sock, STOP)
except Exception:
st.alive = False
if st.molecule_id >= 0 and self.bound.get(st.molecule_id) is st:
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
self._pause()
deadline = time.time() + float(wait)
while time.time() < deadline:
time.sleep(self.latency)
with self._lock:
for st in list(self.clients.values()):
if not st or not st.alive:
continue
try:
# Make reads snappy during shutdown
st.sock.settimeout(self.latency)
msg = _recv_msg(st.sock)
if msg == BYE:
# Clean close on our side
st.alive = False
if (
st.molecule_id >= 0
and self.bound.get(st.molecule_id) is st
):
self._log(
f"DISCONNECTED: mol {st.molecule_id} from {st.address}"
)
self.bound[st.molecule_id] = None
try:
st.sock.shutdown(socket.SHUT_RDWR)
except Exception:
pass
try:
st.sock.close()
except Exception:
pass
except (socket.timeout, _SocketClosed, OSError):
# Either no message yet or peer closed already; keep sweeping
continue
[docs]
def stop(self):
"""
Stop accepting new connections, request clients to exit, and close sockets.
Also removes the UNIX socket path if one was created.
"""
# First, stop accepting new connections
self._stop = True
try:
self.serversock.close()
except Exception:
pass
# Then, gracefully end existing sessions
try:
self.graceful_shutdown(wait=max(2.0, 10 * self.latency))
finally:
with self._lock:
for st in list(self.clients.values()):
try:
st.sock.close()
except Exception:
pass
# if unix socket, remove the path
if self.unixsocket_path and os.path.exists(self.unixsocket_path):
os.unlink(self.unixsocket_path)
print(f"[SocketHub] Unlinked unix socket path {self.unixsocket_path}")