diff --git a/cq/_core/dispatchers/abc.py b/cq/_core/dispatchers/abc.py index 38386a1..7b6d194 100644 --- a/cq/_core/dispatchers/abc.py +++ b/cq/_core/dispatchers/abc.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable from typing import Protocol, Self, runtime_checkable -from cq._core.middleware import Middleware, MiddlewareGroup +from cq._core.middleware import Middleware, MiddlewareGroup, deliver_message @runtime_checkable @@ -29,17 +29,16 @@ def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self: self.__middleware_group.add(*middlewares) return self - async def _invoke_with_middlewares( + async def _deliver( self, - handler: Callable[[I], Awaitable[O]], message: I, + handler: Callable[[I], Awaitable[O]], /, fail_silently: bool = False, ) -> O: - try: - return await self.__middleware_group.invoke(handler, message) - except Exception: - if fail_silently: - return NotImplemented - - raise + return await deliver_message( + message, + handler, + self.__middleware_group, + fail_silently, + ) diff --git a/cq/_core/dispatchers/bus.py b/cq/_core/dispatchers/bus.py index 12fd9f2..03e1657 100644 --- a/cq/_core/dispatchers/bus.py +++ b/cq/_core/dispatchers/bus.py @@ -83,11 +83,7 @@ async def dispatch(self, message: I, /) -> O: self._trigger_listeners(message, task_group) for handler in self._handlers_from(type(message)): - return await self._invoke_with_middlewares( - handler, - message, - handler.fail_silently, - ) + return await self._deliver(message, handler, handler.fail_silently) return NotImplemented @@ -104,8 +100,8 @@ async def dispatch(self, message: I, /) -> None: for handler in self._handlers_from(type(message)): task_group.start_soon( - self._invoke_with_middlewares, - handler, + self._deliver, message, + handler, handler.fail_silently, ) diff --git a/cq/_core/dispatchers/pipe.py b/cq/_core/dispatchers/pipe.py index aa320ff..6f2180b 100644 --- a/cq/_core/dispatchers/pipe.py +++ b/cq/_core/dispatchers/pipe.py @@ -140,7 +140,7 @@ def add_static_step[T]( return self async def dispatch(self, message: I, /) -> O: - return await self._invoke_with_middlewares(self.__steps.execute, message) + return await self._deliver(message, self.__steps.execute) class ContextPipeline[I]: diff --git a/cq/_core/middleware.py b/cq/_core/middleware.py index 540f7f8..9271c15 100644 --- a/cq/_core/middleware.py +++ b/cq/_core/middleware.py @@ -115,6 +115,21 @@ async def __call__( return value +async def deliver_message[I, O]( + message: I, + handler: Callable[[I], Awaitable[O]], + middleware_group: MiddlewareGroup[[I], O], + fail_silently: bool = False, +) -> O: + try: + return await middleware_group.invoke(handler, message) + except Exception: + if fail_silently: + return NotImplemented + + raise + + def _is_gen_middleware[**P, T]( middleware: Middleware[P, T], ) -> TypeGuard[GeneratorMiddleware[P, T]]: diff --git a/cq/_core/pump.py b/cq/_core/pump.py index c0cb19d..18ca376 100644 --- a/cq/_core/pump.py +++ b/cq/_core/pump.py @@ -1,10 +1,11 @@ from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any +from typing import Any, Self import anyio +from cq._core.middleware import Middleware, MiddlewareGroup, deliver_message from cq._core.queues.abc import Consumer @@ -13,19 +14,40 @@ class Pump[T]: consumer: Consumer[T] dispatcher: Callable[[T], Awaitable[Any]] fail_silently: bool = field(default=False) + __middleware_group: MiddlewareGroup[[T], Any] = field( + default_factory=MiddlewareGroup, + init=False, + ) + + def add_middlewares(self, *middlewares: Middleware[[T], Any]) -> Self: + self.__middleware_group.add(*middlewares) + return self async def drain(self) -> None: async for message in self.consumer: - try: - await self.dispatcher(message) - except Exception: - if not self.fail_silently: - raise + await deliver_message( + message, + self.dispatcher, + self.__middleware_group, + self.fail_silently, + ) @asynccontextmanager - async def draining(self, /, *, graceful: bool = False) -> AsyncIterator[None]: + async def draining( + self, + /, + *, + concurrency: int | None = None, + graceful: bool = False, + ) -> AsyncIterator[None]: + if concurrency is None: + concurrency = 1 + elif concurrency < 1: + raise ValueError(f"`concurrency` must be at least 1, got {concurrency}.") + async with anyio.create_task_group() as task_group: - task_group.start_soon(self.drain) + for _ in range(concurrency): + task_group.start_soon(self.drain) try: yield diff --git a/cq/_core/queues/memory.py b/cq/_core/queues/memory.py index 6344652..cf6ec60 100644 --- a/cq/_core/queues/memory.py +++ b/cq/_core/queues/memory.py @@ -1,10 +1,11 @@ -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence from contextlib import asynccontextmanager from typing import Any, Self import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from cq._core.middleware import Middleware from cq._core.pump import Pump from cq._core.queues.abc import Queue @@ -30,9 +31,15 @@ async def draining( dispatcher: Callable[[T], Awaitable[Any]], /, *, + concurrency: int | None = None, fail_silently: bool = False, + middlewares: Sequence[Middleware[[T], Any]] = (), ) -> AsyncIterator[Self]: - async with Pump(self, dispatcher, fail_silently).draining(graceful=True): + async with ( + Pump(self, dispatcher, fail_silently) + .add_middlewares(*middlewares) + .draining(concurrency=concurrency, graceful=True) + ): try: yield self finally: diff --git a/docs/guides/configuring.md b/docs/guides/configuring.md index 1b2584e..d012699 100644 --- a/docs/guides/configuring.md +++ b/docs/guides/configuring.md @@ -95,7 +95,7 @@ class TimingMiddleware: async def __call__(self, message): start = time.time() yield - self.metrics.record(time.time() - start) + await self.metrics.record(time.time() - start) @dataclass class ClassicTimingMiddleware: @@ -104,7 +104,7 @@ class ClassicTimingMiddleware: async def __call__(self, call_next, message): start = time.time() result = await call_next(message) - self.metrics.record(time.time() - start) + await self.metrics.record(time.time() - start) return result ``` diff --git a/docs/guides/queues.md b/docs/guides/queues.md index e54d329..d6b9263 100644 --- a/docs/guides/queues.md +++ b/docs/guides/queues.md @@ -68,6 +68,40 @@ Pump(queue, command_bus, fail_silently=True) Use this when each message is independent and a failed dispatch should not take down the whole consumer. Logging or alerting on failures is the dispatcher's responsibility, typically through a middleware on the underlying bus. +### Concurrent consumption + +A single drain loop pulls and dispatches messages sequentially, which means the pump naturally paces itself to the dispatcher's speed. Pass `concurrency=N` to `draining` to spawn `N` drain tasks against the same `Consumer`: + +```python +async with Pump(queue, command_bus).draining(concurrency=4): + # ... +``` + +This requires the `Consumer` implementation to be safe for concurrent iteration. `MemoryQueue` is, since it wraps an `anyio` memory stream; if you write a custom `Consumer`, make sure its `__aiter__` can be consumed from several tasks at once. + +Note that ordering is no longer guaranteed when `concurrency > 1`: faster dispatches will overtake slower ones even if the queue delivers messages in order. + +### Middlewares + +A `Pump` accepts its own middleware stack via `add_middlewares`, with the same syntax and conventions as a bus (see [Middlewares](configuring.md#middlewares)): + +```python +async def sentry_middleware(message): + try: + yield + except Exception as exc: + sentry_sdk.capture_exception(exc) + +pump = Pump(queue, command_bus).add_middlewares(sentry_middleware) +``` + +Pump-level middlewares wrap the consumption cycle, while bus-level middlewares still apply to every message dispatched. The distinction matters when deciding where to put a given concern: + +* **On the pump**: anything tied to the consumption cycle, such as reporting failures that escape the dispatcher or capping concurrent dispatches when `concurrency > 1`. +* **On the bus**: anything tied to handling the message, such as input validation, business-level retries, or transactional boundaries. + +Note that `fail_silently=True` also swallows exceptions raised from a pump middleware, so a middleware that intentionally aborts the pump will be suppressed in that mode. + ## `MemoryQueue.draining` shortcut `MemoryQueue` exposes a `draining` helper that combines `Pump` with the queue's own lifecycle. It opens a pump, yields the queue, and closes it on exit so the pump terminates gracefully without an explicit `close` call: @@ -86,4 +120,4 @@ async def main(command_bus: CommandBus[Any]) -> None: # Both commands have been dispatched here. ``` -`fail_silently` is forwarded to the underlying `Pump`. +`concurrency`, `fail_silently`, and `middlewares` are forwarded to the underlying `Pump`. diff --git a/tests/core/dispatcher/__init__.py b/tests/core/dispatchers/__init__.py similarity index 100% rename from tests/core/dispatcher/__init__.py rename to tests/core/dispatchers/__init__.py diff --git a/tests/core/dispatcher/test_bus.py b/tests/core/dispatchers/test_bus.py similarity index 100% rename from tests/core/dispatcher/test_bus.py rename to tests/core/dispatchers/test_bus.py diff --git a/tests/core/dispatcher/test_pipe.py b/tests/core/dispatchers/test_pipe.py similarity index 100% rename from tests/core/dispatcher/test_pipe.py rename to tests/core/dispatchers/test_pipe.py diff --git a/tests/core/queue/__init__.py b/tests/core/queues/__init__.py similarity index 100% rename from tests/core/queue/__init__.py rename to tests/core/queues/__init__.py diff --git a/tests/core/queue/test_memory.py b/tests/core/queues/test_memory.py similarity index 100% rename from tests/core/queue/test_memory.py rename to tests/core/queues/test_memory.py diff --git a/tests/core/test_pump.py b/tests/core/test_pump.py index 24ccd30..3a66c53 100644 --- a/tests/core/test_pump.py +++ b/tests/core/test_pump.py @@ -1,6 +1,9 @@ +from typing import Any + import anyio +import pytest -from cq import MemoryQueue, Pump +from cq import MemoryQueue, MiddlewareResult, Pump class TestPump: @@ -9,7 +12,7 @@ async def test_draining_without_graceful(self) -> None: async def dispatcher(message: str) -> None: await anyio.sleep(60) - dispatched.set() + dispatched.set() # pragma: no cover queue = MemoryQueue[str]() @@ -17,3 +20,34 @@ async def dispatcher(message: str) -> None: await queue.send("message") assert not dispatched.is_set() + + async def test_draining_with_middleware(self) -> None: + entered = anyio.Event() + exited = anyio.Event() + + async def middleware(message: str) -> MiddlewareResult[Any]: + entered.set() + try: + yield + finally: + exited.set() + + async def dispatcher(message: str) -> None: ... + + queue = MemoryQueue[str]() + + async with Pump(queue, dispatcher).add_middlewares(middleware).draining(): + await queue.send("message") + + assert entered.is_set() and exited.is_set() + + async def test_draining_when_concurrency_is_less_than_1_raise_value_error( + self, + ) -> None: + async def dispatcher(message: str) -> None: ... + + queue = MemoryQueue[str]() + + with pytest.raises(ValueError): + async with Pump(queue, dispatcher).draining(concurrency=0): + raise NotImplementedError diff --git a/uv.lock b/uv.lock index 6d54c6c..4d5d6fd 100644 --- a/uv.lock +++ b/uv.lock @@ -554,14 +554,14 @@ wheels = [ [[package]] name = "jaraco-functools" -version = "4.4.0" +version = "4.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "more-itertools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/27/056e0638a86749374d6f57d0b0db39f29509cce9313cf91bdc0ac4d91084/jaraco_functools-4.4.0.tar.gz", hash = "sha256:da21933b0417b89515562656547a77b4931f98176eb173644c0d35032a33d6bb", size = 19943, upload-time = "2025-12-21T09:29:43.6Z" } +sdist = { url = "https://files.pythonhosted.org/packages/36/cf/ea4ef2920830dea3f5ab2ea4da6fb67724e6dca80ee2553788c3607243d0/jaraco_functools-4.5.0.tar.gz", hash = "sha256:3bb5665ea4a020cf78a7040e89154c77edadb3ca74f366479669c5999aa70b03", size = 20272, upload-time = "2026-05-15T21:34:10.025Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/c4/813bb09f0985cb21e959f21f2464169eca882656849adf727ac7bb7e1767/jaraco_functools-4.4.0-py3-none-any.whl", hash = "sha256:9eec1e36f45c818d9bf307c8948eb03b2b56cd44087b3cdc989abca1f20b9176", size = 10481, upload-time = "2025-12-21T09:29:42.27Z" }, + { url = "https://files.pythonhosted.org/packages/96/9a/982e48afcffcd727a9144506720ffd4224b6b7e355c98641866f38b7c043/jaraco_functools-4.5.0-py3-none-any.whl", hash = "sha256:79ce39246eddbde4b3a03b77ea5f0f7878dc669b166a66cf3fa8e266aa3fa2f4", size = 10594, upload-time = "2026-05-15T21:34:08.595Z" }, ] [[package]]