Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions cq/_core/dispatchers/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Awaitable, Callable
from typing import Protocol, Self, runtime_checkable

from cq._core.middleware import Middleware, MiddlewareGroup
from cq._core.middleware import Middleware, MiddlewareGroup, deliver_message


@runtime_checkable
Expand All @@ -29,17 +29,16 @@ def add_middlewares(self, *middlewares: Middleware[[I], O]) -> Self:
self.__middleware_group.add(*middlewares)
return self

async def _invoke_with_middlewares(
async def _deliver(
self,
handler: Callable[[I], Awaitable[O]],
message: I,
handler: Callable[[I], Awaitable[O]],
/,
fail_silently: bool = False,
) -> O:
try:
return await self.__middleware_group.invoke(handler, message)
except Exception:
if fail_silently:
return NotImplemented

raise
return await deliver_message(
message,
handler,
self.__middleware_group,
fail_silently,
)
10 changes: 3 additions & 7 deletions cq/_core/dispatchers/bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ async def dispatch(self, message: I, /) -> O:
self._trigger_listeners(message, task_group)

for handler in self._handlers_from(type(message)):
return await self._invoke_with_middlewares(
handler,
message,
handler.fail_silently,
)
return await self._deliver(message, handler, handler.fail_silently)

return NotImplemented

Expand All @@ -104,8 +100,8 @@ async def dispatch(self, message: I, /) -> None:

for handler in self._handlers_from(type(message)):
task_group.start_soon(
self._invoke_with_middlewares,
handler,
self._deliver,
message,
handler,
handler.fail_silently,
)
2 changes: 1 addition & 1 deletion cq/_core/dispatchers/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def add_static_step[T](
return self

async def dispatch(self, message: I, /) -> O:
return await self._invoke_with_middlewares(self.__steps.execute, message)
return await self._deliver(message, self.__steps.execute)


class ContextPipeline[I]:
Expand Down
15 changes: 15 additions & 0 deletions cq/_core/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@ async def __call__(
return value


async def deliver_message[I, O](
message: I,
handler: Callable[[I], Awaitable[O]],
middleware_group: MiddlewareGroup[[I], O],
fail_silently: bool = False,
) -> O:
try:
return await middleware_group.invoke(handler, message)
except Exception:
if fail_silently:
return NotImplemented

raise


def _is_gen_middleware[**P, T](
middleware: Middleware[P, T],
) -> TypeGuard[GeneratorMiddleware[P, T]]:
Expand Down
38 changes: 30 additions & 8 deletions cq/_core/pump.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Self

import anyio

from cq._core.middleware import Middleware, MiddlewareGroup, deliver_message
from cq._core.queues.abc import Consumer


Expand All @@ -13,19 +14,40 @@ class Pump[T]:
consumer: Consumer[T]
dispatcher: Callable[[T], Awaitable[Any]]
fail_silently: bool = field(default=False)
__middleware_group: MiddlewareGroup[[T], Any] = field(
default_factory=MiddlewareGroup,
init=False,
)

def add_middlewares(self, *middlewares: Middleware[[T], Any]) -> Self:
self.__middleware_group.add(*middlewares)
return self

async def drain(self) -> None:
async for message in self.consumer:
try:
await self.dispatcher(message)
except Exception:
if not self.fail_silently:
raise
await deliver_message(
message,
self.dispatcher,
self.__middleware_group,
self.fail_silently,
)

@asynccontextmanager
async def draining(self, /, *, graceful: bool = False) -> AsyncIterator[None]:
async def draining(
self,
/,
*,
concurrency: int | None = None,
graceful: bool = False,
) -> AsyncIterator[None]:
if concurrency is None:
concurrency = 1
elif concurrency < 1:
raise ValueError(f"`concurrency` must be at least 1, got {concurrency}.")

async with anyio.create_task_group() as task_group:
task_group.start_soon(self.drain)
for _ in range(concurrency):
task_group.start_soon(self.drain)

try:
yield
Expand Down
11 changes: 9 additions & 2 deletions cq/_core/queues/memory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import AsyncIterator, Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from contextlib import asynccontextmanager
from typing import Any, Self

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from cq._core.middleware import Middleware
from cq._core.pump import Pump
from cq._core.queues.abc import Queue

Expand All @@ -30,9 +31,15 @@ async def draining(
dispatcher: Callable[[T], Awaitable[Any]],
/,
*,
concurrency: int | None = None,
fail_silently: bool = False,
middlewares: Sequence[Middleware[[T], Any]] = (),
) -> AsyncIterator[Self]:
async with Pump(self, dispatcher, fail_silently).draining(graceful=True):
async with (
Pump(self, dispatcher, fail_silently)
.add_middlewares(*middlewares)
.draining(concurrency=concurrency, graceful=True)
):
try:
yield self
finally:
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TimingMiddleware:
async def __call__(self, message):
start = time.time()
yield
self.metrics.record(time.time() - start)
await self.metrics.record(time.time() - start)

@dataclass
class ClassicTimingMiddleware:
Expand All @@ -104,7 +104,7 @@ class ClassicTimingMiddleware:
async def __call__(self, call_next, message):
start = time.time()
result = await call_next(message)
self.metrics.record(time.time() - start)
await self.metrics.record(time.time() - start)
return result
```

Expand Down
36 changes: 35 additions & 1 deletion docs/guides/queues.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,40 @@ Pump(queue, command_bus, fail_silently=True)

Use this when each message is independent and a failed dispatch should not take down the whole consumer. Logging or alerting on failures is the dispatcher's responsibility, typically through a middleware on the underlying bus.

### Concurrent consumption

A single drain loop pulls and dispatches messages sequentially, which means the pump naturally paces itself to the dispatcher's speed. Pass `concurrency=N` to `draining` to spawn `N` drain tasks against the same `Consumer`:

```python
async with Pump(queue, command_bus).draining(concurrency=4):
# ...
```

This requires the `Consumer` implementation to be safe for concurrent iteration. `MemoryQueue` is, since it wraps an `anyio` memory stream; if you write a custom `Consumer`, make sure its `__aiter__` can be consumed from several tasks at once.

Note that ordering is no longer guaranteed when `concurrency > 1`: faster dispatches will overtake slower ones even if the queue delivers messages in order.

### Middlewares

A `Pump` accepts its own middleware stack via `add_middlewares`, with the same syntax and conventions as a bus (see [Middlewares](configuring.md#middlewares)):

```python
async def sentry_middleware(message):
try:
yield
except Exception as exc:
sentry_sdk.capture_exception(exc)

pump = Pump(queue, command_bus).add_middlewares(sentry_middleware)
```

Pump-level middlewares wrap the consumption cycle, while bus-level middlewares still apply to every message dispatched. The distinction matters when deciding where to put a given concern:

* **On the pump**: anything tied to the consumption cycle, such as reporting failures that escape the dispatcher or capping concurrent dispatches when `concurrency > 1`.
* **On the bus**: anything tied to handling the message, such as input validation, business-level retries, or transactional boundaries.

Note that `fail_silently=True` also swallows exceptions raised from a pump middleware, so a middleware that intentionally aborts the pump will be suppressed in that mode.

## `MemoryQueue.draining` shortcut

`MemoryQueue` exposes a `draining` helper that combines `Pump` with the queue's own lifecycle. It opens a pump, yields the queue, and closes it on exit so the pump terminates gracefully without an explicit `close` call:
Expand All @@ -86,4 +120,4 @@ async def main(command_bus: CommandBus[Any]) -> None:
# Both commands have been dispatched here.
```

`fail_silently` is forwarded to the underlying `Pump`.
`concurrency`, `fail_silently`, and `middlewares` are forwarded to the underlying `Pump`.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
38 changes: 36 additions & 2 deletions tests/core/test_pump.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any

import anyio
import pytest

from cq import MemoryQueue, Pump
from cq import MemoryQueue, MiddlewareResult, Pump


class TestPump:
Expand All @@ -9,11 +12,42 @@ async def test_draining_without_graceful(self) -> None:

async def dispatcher(message: str) -> None:
await anyio.sleep(60)
dispatched.set()
dispatched.set() # pragma: no cover

queue = MemoryQueue[str]()

async with Pump(queue, dispatcher).draining(graceful=False):
await queue.send("message")

assert not dispatched.is_set()

async def test_draining_with_middleware(self) -> None:
entered = anyio.Event()
exited = anyio.Event()

async def middleware(message: str) -> MiddlewareResult[Any]:
entered.set()
try:
yield
finally:
exited.set()

async def dispatcher(message: str) -> None: ...

queue = MemoryQueue[str]()

async with Pump(queue, dispatcher).add_middlewares(middleware).draining():
await queue.send("message")

assert entered.is_set() and exited.is_set()

async def test_draining_when_concurrency_is_less_than_1_raise_value_error(
self,
) -> None:
async def dispatcher(message: str) -> None: ...

queue = MemoryQueue[str]()

with pytest.raises(ValueError):
async with Pump(queue, dispatcher).draining(concurrency=0):
raise NotImplementedError
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.