from __future__ import annotations
import asyncio
from dataclasses import dataclass
class QueueFullError(RuntimeError):
    """Signal that a run request was rejected because a pending-queue cap was hit.

    Raised by the admission controller instead of enqueueing the request,
    either for the overall pending queue or for one user's share of it.
    """
@dataclass(slots=True)
class _QueuedRun:
"""One pending run request waiting for admission."""
request_id: int
user_key: str
future: asyncio.Future
class RunAdmissionController:
    """In-memory admission queue for global/per-user run concurrency limits.

    A caller awaits :meth:`acquire` before starting a run and awaits
    :meth:`release` with the same user key once the run has finished.
    Requests that cannot start immediately wait in an ordered pending list
    and are promoted as slots free up; a pending request whose user is at the
    per-user limit is skipped over instead of blocking the requests behind it.
    All counters are guarded by a single ``asyncio.Lock``.
    """

    def __init__(
        self,
        *,
        global_limit: int,
        per_user_limit: int,
        max_queue_size: int = 0,
        max_pending_per_user: int = 0,
    ) -> None:
        """Create a controller with the given limits.

        Parameters
        ----------
        global_limit : int
            Maximum number of simultaneously active runs; clamped to at least 1.
        per_user_limit : int
            Maximum number of simultaneously active runs per user key;
            clamped to at least 1.
        max_queue_size : int, optional
            Cap on the total number of pending requests. ``0`` (the default,
            also used for negative values) disables the cap.
        max_pending_per_user : int, optional
            Cap on pending requests per user key. ``0`` (the default, also
            used for negative values) disables the cap.
        """
        self.global_limit = max(1, int(global_limit))
        self.per_user_limit = max(1, int(per_user_limit))
        self.max_queue_size = max(0, int(max_queue_size))
        self.max_pending_per_user = max(0, int(max_pending_per_user))
        self._lock = asyncio.Lock()
        self._active_total = 0
        self._active_by_user: dict[str, int] = {}
        self._pending_by_user: dict[str, int] = {}
        self._pending: list[_QueuedRun] = []
        self._next_request_id = 1

    @staticmethod
    def _normalize_user_key(user_key: str | None) -> str:
        """Normalize user keys used by queue counters.

        Keys are stripped and lower-cased; ``None`` or a blank key maps to
        ``"anonymous"``.
        """
        cleaned = (user_key or "").strip().lower()
        return cleaned or "anonymous"

    def _can_run_locked(self, user_key: str) -> bool:
        """Check whether a user can start immediately under active counters."""
        if self._active_total >= self.global_limit:
            return False
        return self._active_by_user.get(user_key, 0) < self.per_user_limit

    def _grant_locked(
        self, *, user_key: str, request_id: int, queued: bool
    ) -> dict[str, int | bool]:
        """Reserve one active slot and return grant metadata.

        The returned counters reflect the state right after the reservation.
        """
        self._active_total += 1
        per_user_active = self._active_by_user.get(user_key, 0) + 1
        self._active_by_user[user_key] = per_user_active
        return {
            "request_id": request_id,
            "queued": queued,
            "active_total": self._active_total,
            "global_limit": self.global_limit,
            "per_user_active": per_user_active,
            "per_user_limit": self.per_user_limit,
            "pending_total": len(self._pending),
        }

    def _release_locked(self, user_key: str) -> bool:
        """Give back one active slot held by ``user_key``.

        Returns ``False`` (and changes nothing) when the user holds no active
        slot, ``True`` otherwise.
        """
        current = self._active_by_user.get(user_key, 0)
        if current <= 0:
            return False
        if current == 1:
            # Drop the entry entirely so idle users do not accumulate.
            self._active_by_user.pop(user_key, None)
        else:
            self._active_by_user[user_key] = current - 1
        if self._active_total > 0:
            self._active_total -= 1
        return True

    def _increment_pending_user_locked(self, user_key: str) -> None:
        """Increment pending queue count for one user."""
        self._pending_by_user[user_key] = self._pending_by_user.get(user_key, 0) + 1

    def _decrement_pending_user_locked(self, user_key: str) -> None:
        """Decrement pending queue count for one user, dropping zero entries."""
        current = self._pending_by_user.get(user_key, 0)
        if current <= 1:
            self._pending_by_user.pop(user_key, None)
            return
        self._pending_by_user[user_key] = current - 1

    def _remove_pending_request_locked(self, request_id: int) -> bool:
        """Remove one queued request by id and update per-user counters.

        Returns ``True`` when the request was still queued, ``False`` when no
        queued request has that id.
        """
        for idx, item in enumerate(self._pending):
            if item.request_id != request_id:
                continue
            self._pending.pop(idx)
            self._decrement_pending_user_locked(item.user_key)
            return True
        return False

    def _dispatch_pending_locked(self) -> None:
        """Promote the first eligible queued runs while capacity exists."""
        while self._active_total < self.global_limit and self._pending:
            # First queued request whose user is still under the per-user limit.
            eligible_idx = next(
                (
                    idx
                    for idx, item in enumerate(self._pending)
                    if self._active_by_user.get(item.user_key, 0)
                    < self.per_user_limit
                ),
                None,
            )
            if eligible_idx is None:
                break
            pending = self._pending.pop(eligible_idx)
            self._decrement_pending_user_locked(pending.user_key)
            if pending.future.done():
                # Nobody can receive this grant (e.g. the waiter was
                # cancelled), so do not reserve a slot for it.
                continue
            grant = self._grant_locked(
                user_key=pending.user_key,
                request_id=pending.request_id,
                queued=True,
            )
            pending.future.set_result(grant)

    async def acquire(self, user_key: str | None) -> dict[str, int | bool]:
        """Acquire admission for a run, queueing when limits are saturated.

        Parameters
        ----------
        user_key : str or None
            User identifier the run is accounted against (normalized first).

        Returns
        -------
        dict[str, int | bool]
            Grant metadata: ``request_id``, ``queued`` (``False`` only when
            admitted without going through the pending list) and the counters
            ``active_total``, ``global_limit``, ``per_user_active``,
            ``per_user_limit`` and ``pending_total``.

        Raises
        ------
        QueueFullError
            If the request has to be queued and a pending cap is reached.
        asyncio.CancelledError
            If the waiting task is cancelled; no slot stays reserved for it.
        """
        normalized_user = self._normalize_user_key(user_key)
        async with self._lock:
            if self._can_run_locked(normalized_user) and not self._pending:
                request_id = self._next_request_id
                self._next_request_id += 1
                return self._grant_locked(
                    user_key=normalized_user,
                    request_id=request_id,
                    queued=False,
                )
            # NOTE(review): the caps below also apply to a request that could
            # start right away when the queue only holds per-user-limited
            # entries (it would be promoted by the dispatch further down).
            # Confirm whether rejecting such a request is intended.
            if (
                self.max_pending_per_user > 0
                and self._pending_by_user.get(normalized_user, 0)
                >= self.max_pending_per_user
            ):
                raise QueueFullError(
                    "Runner queue is full for this user "
                    f"({self.max_pending_per_user} pending requests)."
                )
            if self.max_queue_size > 0 and len(self._pending) >= self.max_queue_size:
                raise QueueFullError(
                    f"Runner queue is full ({self.max_queue_size} pending requests)."
                )
            request_id = self._next_request_id
            self._next_request_id += 1
            pending = _QueuedRun(
                request_id=request_id,
                user_key=normalized_user,
                future=asyncio.get_running_loop().create_future(),
            )
            self._pending.append(pending)
            self._increment_pending_user_locked(normalized_user)
            # Avoid head-of-line blocking when queue head is per-user limited.
            self._dispatch_pending_locked()
            if pending.future.done():
                return pending.future.result()
        # Wait outside the lock so release() can run and promote this request.
        try:
            return await pending.future
        except asyncio.CancelledError:
            async with self._lock:
                still_queued = self._remove_pending_request_locked(
                    pending.request_id
                )
                if (
                    not still_queued
                    and pending.future.done()
                    and not pending.future.cancelled()
                ):
                    # The grant was issued just before the cancellation was
                    # delivered. The caller never sees it and so can never
                    # release it; hand the slot back here to avoid a leak.
                    self._release_locked(pending.user_key)
                self._dispatch_pending_locked()
            raise

    async def release(self, user_key: str | None) -> None:
        """Release one active slot and admit queued work if possible.

        A release for a user without any active slot is ignored.
        """
        normalized_user = self._normalize_user_key(user_key)
        async with self._lock:
            if self._release_locked(normalized_user):
                self._dispatch_pending_locked()

    async def snapshot(self) -> dict[str, int]:
        """Return current queue/accounting counters (for tests/diagnostics)."""
        async with self._lock:
            return {
                "active_total": self._active_total,
                "pending_total": len(self._pending),
                "global_limit": self.global_limit,
                "per_user_limit": self.per_user_limit,
            }

    async def snapshot_for_user(self, user_key: str | None) -> dict[str, int | bool]:
        """Return current queue/accounting counters for one normalized user key.

        Parameters
        ----------
        user_key : str or None
            User identifier to inspect.

        Returns
        -------
        dict[str, int | bool]
            Snapshot containing global/per-user counters and `can_run_now`.
        """
        normalized_user = self._normalize_user_key(user_key)
        async with self._lock:
            per_user_active = self._active_by_user.get(normalized_user, 0)
            per_user_pending = self._pending_by_user.get(normalized_user, 0)
            # Mirrors the fast path in acquire(): limits allow it and nothing
            # is waiting in the queue.
            can_run_now = self._can_run_locked(normalized_user) and not self._pending
            return {
                "user_key": normalized_user,
                "can_run_now": can_run_now,
                "per_user_active": per_user_active,
                "per_user_pending": per_user_pending,
                "active_total": self._active_total,
                "pending_total": len(self._pending),
                "global_limit": self.global_limit,
                "per_user_limit": self.per_user_limit,
            }