From 15615c27d36b2e6b65ebdcaa9efae2256a7d030e Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Thu, 7 May 2026 15:49:13 +0200 Subject: [PATCH 1/5] feat(taskbroker-client): Add pass_headers option to task registration Allow tasks to opt into receiving activation headers by setting pass_headers=True in the @register() decorator. When enabled, the task function receives headers as a keyword argument (dict[str, str]). This provides a simpler alternative to implementing a full ContextHook for tasks that just need direct access to header values. Co-Authored-By: Claude Opus 4.5 --- clients/python/src/examples/tasks.py | 9 ++++ .../python/src/taskbroker_client/registry.py | 7 +++ clients/python/src/taskbroker_client/task.py | 7 ++- clients/python/src/taskbroker_client/types.py | 3 ++ .../taskbroker_client/worker/workerchild.py | 5 +- clients/python/tests/test_task.py | 47 +++++++++++++++++++ clients/python/tests/worker/test_worker.py | 44 +++++++++++++++++ 7 files changed, 120 insertions(+), 2 deletions(-) diff --git a/clients/python/src/examples/tasks.py b/clients/python/src/examples/tasks.py index c7795dc8..2faa2b8b 100644 --- a/clients/python/src/examples/tasks.py +++ b/clients/python/src/examples/tasks.py @@ -109,3 +109,12 @@ def will_fail_with_silenced_ignored_exception() -> None: ) def will_retry_on_deadline_exceeded() -> None: timed_task(sleep_seconds=2) + + +@exampletasks.register(name="examples.task_with_headers", pass_headers=True) +def task_with_headers(value: str, headers: dict[str, str]) -> None: + redis = StrictRedis(host="localhost", port=6379, decode_responses=True) + redis.set("task-headers-value", value) + redis.set("task-headers-count", str(len(headers))) + if "x-custom-header" in headers: + redis.set("task-headers-custom", headers["x-custom-header"]) diff --git a/clients/python/src/taskbroker_client/registry.py b/clients/python/src/taskbroker_client/registry.py index d7d33da0..93aee24d 100644 --- a/clients/python/src/taskbroker_client/registry.py +++ b/clients/python/src/taskbroker_client/registry.py @@ -90,6 +90,7 @@ def register( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ) -> Callable[[Callable[P, R]], Task[P, R]]: """ Register a task. @@ -121,6 +122,9 @@ def register( Enable reporting of ProcessingDeadlineExceededError to Sentry. silenced_exceptions: tuple[type[BaseException], ...] | None A tuple of exception types that will not be reported by Sentry. + pass_headers: bool + If True, the task function will receive task activation headers + as a keyword argument named `headers` (dict[str, str]). """ def wrapped(func: Callable[P, R]) -> Task[P, R]: @@ -141,6 +145,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: compression_type=compression_type, report_timeout_errors=report_timeout_errors, silenced_exceptions=silenced_exceptions, + pass_headers=pass_headers, ) # TODO(taskworker) tasks should be registered into the registry # so that we can ensure task names are globally unique @@ -234,6 +239,7 @@ def register( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ) -> Callable[[Callable[P, R]], ExternalTask[P, R]]: """ Register an external task stub. @@ -285,6 +291,7 @@ def wrapped(func: Callable[P, R]) -> ExternalTask[P, R]: compression_type=compression_type, report_timeout_errors=report_timeout_errors, silenced_exceptions=silenced_exceptions, + pass_headers=pass_headers, ) self._registered_tasks[name] = task return task diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index f3ec77a2..7121e363 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -68,6 +68,7 @@ def __init__( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ): self.name = name self._func = func @@ -88,6 +89,7 @@ def __init__( self.compression_type = compression_type self.report_timeout_errors = report_timeout_errors self.silenced_exceptions = silenced_exceptions or () + self.pass_headers = pass_headers update_wrapper(self, func) @property @@ -154,7 +156,10 @@ def apply_async( def _call_func(self, *args: Any, **kwargs: Any) -> None: # Overridden in ExternalTask - self._func(*args, **kwargs) + if self.pass_headers: + self._func(*args, headers={}, **kwargs) # type: ignore[arg-type] + else: + self._func(*args, **kwargs) def _signal_send(self, task: Task[Any, Any], args: Any, kwargs: Any) -> None: """ diff --git a/clients/python/src/taskbroker_client/types.py b/clients/python/src/taskbroker_client/types.py index c4a3783f..326ddd96 100644 --- a/clients/python/src/taskbroker_client/types.py +++ b/clients/python/src/taskbroker_client/types.py @@ -8,6 +8,9 @@ from arroyo.types import BrokerValue, Topic from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation, TaskActivationStatus +TaskHeaders = dict[str, str] +"""Headers passed to a task function when pass_headers=True is set.""" + class ContextHook(Protocol): """ diff --git a/clients/python/src/taskbroker_client/worker/workerchild.py b/clients/python/src/taskbroker_client/worker/workerchild.py index 756292a5..aa0fcdea 100644 --- a/clients/python/src/taskbroker_client/worker/workerchild.py +++ b/clients/python/src/taskbroker_client/worker/workerchild.py @@ -402,7 +402,10 @@ def _execute_activation( ): for hook in context_hooks: stack.enter_context(hook.on_execute(headers)) - task_func(*args, **kwargs) + if task_func.pass_headers: + task_func(*args, headers=headers, **kwargs) + else: + task_func(*args, **kwargs) transaction.set_status(SPANSTATUS.OK) except Exception: transaction.set_status(SPANSTATUS.INTERNAL_ERROR) diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index 7fd9785e..a60158b8 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -465,3 +465,50 @@ def multi_task() -> None: activation = multi_task.create_activation([], {}) assert activation.headers["x-test-context"] == "dispatched" assert activation.headers["x-another"] == "also-here" + + +def test_task_pass_headers_attribute(task_namespace: TaskNamespace) -> None: + """Tasks can opt into receiving headers via pass_headers=True.""" + + @task_namespace.register(name="test.with_headers", pass_headers=True) + def with_headers(org_id: int, headers: dict[str, str]) -> None: + pass + + assert with_headers.pass_headers is True + + @task_namespace.register(name="test.without_headers") + def without_headers(org_id: int) -> None: + pass + + assert without_headers.pass_headers is False + + +def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None: + """In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers.""" + calls: list[dict[str, Any]] = [] + + @task_namespace.register(name="test.headers_task", pass_headers=True) + def headers_task(value: str, headers: dict[str, str]) -> None: + calls.append({"value": value, "headers": headers}) + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + headers_task.delay("test") # type: ignore[call-arg] + + assert len(calls) == 1 + assert calls[0]["value"] == "test" + assert calls[0]["headers"] == {} + + +def test_delay_immediate_mode_without_pass_headers(task_namespace: TaskNamespace) -> None: + """In ALWAYS_EAGER mode, tasks without pass_headers do not receive headers kwarg.""" + calls: list[dict[str, Any]] = [] + + @task_namespace.register(name="test.no_headers_task") + def no_headers_task(value: str) -> None: + calls.append({"value": value}) + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + no_headers_task.delay("test") + + assert len(calls) == 1 + assert calls[0] == {"value": "test"} diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 904e4f18..f74ba490 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -269,6 +269,22 @@ ), ) +TASK_WITH_HEADERS = InflightTaskActivation( + host="localhost:50051", + receive_timestamp=0, + activation=TaskActivation( + id="headers_task_123", + taskname="examples.task_with_headers", + namespace="examples", + parameters_bytes=msgpack.packb({"args": ["test_value"], "kwargs": {}}, use_bin_type=True), + headers={ + "x-custom-header": "custom_value", + "sentry-trace": "trace-id", + }, + processing_deadline_duration=2, + ), +) + class TestTaskWorker(TestCase): def test_fetch_task(self) -> None: @@ -814,6 +830,34 @@ def test_child_process_record_checkin(mock_capture_checkin: mock.Mock) -> None: ) +def test_child_process_pass_headers() -> None: + """Task with pass_headers=True receives headers from the activation.""" + todo: queue.Queue[InflightTaskActivation] = queue.Queue() + processed: queue.Queue[ProcessingResult] = queue.Queue() + shutdown = Event() + + todo.put(TASK_WITH_HEADERS) + child_process( + "examples.app:app", + todo, + processed, + shutdown, + max_task_count=1, + processing_pool_name="test", + process_type="fork", + ) + + assert todo.empty() + result = processed.get() + assert result.task_id == TASK_WITH_HEADERS.activation.id + assert result.status == TASK_ACTIVATION_STATUS_COMPLETE + + redis = StrictRedis(host="localhost", port=6379, decode_responses=True) + assert redis.get("task-headers-value") == "test_value" + assert redis.get("task-headers-custom") == "custom_value" + redis.delete("task-headers-value", "task-headers-count", "task-headers-custom") + + @mock.patch("taskbroker_client.worker.workerchild.sentry_sdk.capture_exception") def test_child_process_terminate_task(mock_capture: mock.Mock) -> None: todo: queue.Queue[InflightTaskActivation] = queue.Queue() From 6a018a50fa537ab62fd93b48270d6a357f943321 Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Thu, 7 May 2026 18:16:47 +0200 Subject: [PATCH 2/5] feat(taskbroker-client): Validate headers parameter when pass_headers=True Use inspect module to check at task registration time that: - The function has a 'headers' parameter when pass_headers=True - The 'headers' parameter is not positional-only Co-Authored-By: Claude Opus 4.5 --- clients/python/src/taskbroker_client/task.py | 16 ++++++++++++++++ clients/python/tests/test_task.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index 7121e363..8949089e 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -2,6 +2,7 @@ import base64 import datetime +import inspect import os import time from collections.abc import Callable, Collection, Mapping, MutableMapping @@ -90,6 +91,21 @@ def __init__( self.report_timeout_errors = report_timeout_errors self.silenced_exceptions = silenced_exceptions or () self.pass_headers = pass_headers + + if pass_headers: + sig = inspect.signature(func) + if "headers" not in sig.parameters: + raise TypeError( + f"Task {name!r} has pass_headers=True but the function " + f"does not have a 'headers' parameter" + ) + param = sig.parameters["headers"] + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + raise TypeError( + f"Task {name!r} has pass_headers=True but the 'headers' parameter " + f"is positional-only. It must be a keyword argument." + ) + update_wrapper(self, func) @property diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index a60158b8..d514ecd8 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -483,6 +483,24 @@ def without_headers(org_id: int) -> None: assert without_headers.pass_headers is False +def test_pass_headers_requires_headers_parameter(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True must have a 'headers' parameter.""" + with pytest.raises(TypeError, match="does not have a 'headers' parameter"): + + @task_namespace.register(name="test.missing_headers", pass_headers=True) + def missing_headers(org_id: int) -> None: + pass + + +def test_pass_headers_rejects_positional_only_headers(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True cannot have a positional-only 'headers' parameter.""" + with pytest.raises(TypeError, match="positional-only"): + + @task_namespace.register(name="test.positional_headers", pass_headers=True) + def positional_headers(org_id: int, headers: dict[str, str], /) -> None: + pass + + def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None: """In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers.""" calls: list[dict[str, Any]] = [] From 4408594624519192e7d1feb1a6c0e75463b30e4c Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Thu, 7 May 2026 18:19:36 +0200 Subject: [PATCH 3/5] also check type hints, and move it to a separate function --- clients/python/src/taskbroker_client/task.py | 56 +++++++++++++++----- clients/python/tests/test_task.py | 11 ++++ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index 8949089e..a6938e6e 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -7,7 +7,7 @@ import time from collections.abc import Callable, Collection, Mapping, MutableMapping from functools import update_wrapper -from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, get_origin from uuid import uuid4 import msgpack @@ -55,6 +55,42 @@ def _get_parameters_format() -> ParametersFormat: R = TypeVar("R") +def assert_typed_kwarg( + func: Callable[..., Any], + param_name: str, + expected_types: tuple[type, ...], + context: str, +) -> None: + """ + Validate that a function has a keyword argument with a compatible type annotation. + + Raises TypeError if: + - The parameter does not exist + - The parameter is positional-only + - The parameter has a type annotation that is not in expected_types + """ + sig = inspect.signature(func) + if param_name not in sig.parameters: + raise TypeError(f"{context}: function does not have a {param_name!r} parameter") + + param = sig.parameters[param_name] + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + raise TypeError( + f"{context}: {param_name!r} parameter is positional-only. " + f"It must be a keyword argument." + ) + + if param.annotation is not inspect.Parameter.empty: + origin = get_origin(param.annotation) + if origin is None: + origin = param.annotation + if origin not in expected_types: + raise TypeError( + f"{context}: {param_name!r} parameter has type {param.annotation!r}. " + f"Expected one of: {', '.join(t.__name__ for t in expected_types)}." + ) + + class Task(Generic[P, R]): def __init__( self, @@ -93,18 +129,12 @@ def __init__( self.pass_headers = pass_headers if pass_headers: - sig = inspect.signature(func) - if "headers" not in sig.parameters: - raise TypeError( - f"Task {name!r} has pass_headers=True but the function " - f"does not have a 'headers' parameter" - ) - param = sig.parameters["headers"] - if param.kind == inspect.Parameter.POSITIONAL_ONLY: - raise TypeError( - f"Task {name!r} has pass_headers=True but the 'headers' parameter " - f"is positional-only. It must be a keyword argument." - ) + assert_typed_kwarg( + func, + "headers", + (dict, Mapping, MutableMapping, Any), + f"Task {name!r} with pass_headers=True", + ) update_wrapper(self, func) diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index d514ecd8..8196041f 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -501,6 +501,17 @@ def positional_headers(org_id: int, headers: dict[str, str], /) -> None: pass +def test_pass_headers_rejects_incompatible_type_annotation( + task_namespace: TaskNamespace, +) -> None: + """Tasks with pass_headers=True must have a dict-like type annotation for 'headers'.""" + with pytest.raises(TypeError, match="Expected one of: dict"): + + @task_namespace.register(name="test.wrong_type_headers", pass_headers=True) + def wrong_type_headers(org_id: int, headers: str) -> None: + pass + + def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None: """In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers.""" calls: list[dict[str, Any]] = [] From a99382e59bd3adc1ef489ec845a4821ca6b9a94d Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Thu, 7 May 2026 19:16:46 +0200 Subject: [PATCH 4/5] fix --- clients/python/src/taskbroker_client/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index a6938e6e..89eab408 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -69,7 +69,7 @@ def assert_typed_kwarg( - The parameter is positional-only - The parameter has a type annotation that is not in expected_types """ - sig = inspect.signature(func) + sig = inspect.signature(func, eval_str=True) if param_name not in sig.parameters: raise TypeError(f"{context}: function does not have a {param_name!r} parameter") From 6e002f97adad57dcad5010e48cb354ed3ee6675c Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Fri, 8 May 2026 18:22:56 +0200 Subject: [PATCH 5/5] add testcase --- clients/python/src/taskbroker_client/task.py | 5 +++++ .../src/taskbroker_client/worker/workerchild.py | 5 +++++ clients/python/tests/test_task.py | 12 ++++++++++++ 3 files changed, 22 insertions(+) diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index 89eab408..3fb633d9 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -203,6 +203,11 @@ def apply_async( def _call_func(self, *args: Any, **kwargs: Any) -> None: # Overridden in ExternalTask if self.pass_headers: + if "headers" in kwargs: + raise TypeError( + f"Task '{self.name}' has pass_headers=True, but 'headers' was passed in kwargs. " + "The 'headers' parameter is injected by the worker and cannot be passed by the caller." + ) self._func(*args, headers={}, **kwargs) # type: ignore[arg-type] else: self._func(*args, **kwargs) diff --git a/clients/python/src/taskbroker_client/worker/workerchild.py b/clients/python/src/taskbroker_client/worker/workerchild.py index aa0fcdea..fa2e8476 100644 --- a/clients/python/src/taskbroker_client/worker/workerchild.py +++ b/clients/python/src/taskbroker_client/worker/workerchild.py @@ -403,6 +403,11 @@ def _execute_activation( for hook in context_hooks: stack.enter_context(hook.on_execute(headers)) if task_func.pass_headers: + if "headers" in kwargs: + raise TypeError( + f"Task '{task_func.name}' has pass_headers=True, but 'headers' was passed in kwargs. " + "The 'headers' parameter is injected by the worker and cannot be passed by the caller." + ) task_func(*args, headers=headers, **kwargs) else: task_func(*args, **kwargs) diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index 8196041f..2ea584d4 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -528,6 +528,18 @@ def headers_task(value: str, headers: dict[str, str]) -> None: assert calls[0]["headers"] == {} +def test_pass_headers_rejects_headers_in_kwargs(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True reject 'headers' passed in kwargs.""" + + @task_namespace.register(name="test.headers_collision", pass_headers=True) + def headers_task(value: str, headers: dict[str, str]) -> None: + pass + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + with pytest.raises(TypeError, match="cannot be passed by the caller"): + headers_task.delay("test", headers={"x-custom": "value"}) + + def test_delay_immediate_mode_without_pass_headers(task_namespace: TaskNamespace) -> None: """In ALWAYS_EAGER mode, tasks without pass_headers do not receive headers kwarg.""" calls: list[dict[str, Any]] = []