diff --git a/clients/python/src/examples/tasks.py b/clients/python/src/examples/tasks.py index c7795dc8..2faa2b8b 100644 --- a/clients/python/src/examples/tasks.py +++ b/clients/python/src/examples/tasks.py @@ -109,3 +109,12 @@ def will_fail_with_silenced_ignored_exception() -> None: ) def will_retry_on_deadline_exceeded() -> None: timed_task(sleep_seconds=2) + + +@exampletasks.register(name="examples.task_with_headers", pass_headers=True) +def task_with_headers(value: str, headers: dict[str, str]) -> None: + redis = StrictRedis(host="localhost", port=6379, decode_responses=True) + redis.set("task-headers-value", value) + redis.set("task-headers-count", str(len(headers))) + if "x-custom-header" in headers: + redis.set("task-headers-custom", headers["x-custom-header"]) diff --git a/clients/python/src/taskbroker_client/registry.py b/clients/python/src/taskbroker_client/registry.py index d7d33da0..93aee24d 100644 --- a/clients/python/src/taskbroker_client/registry.py +++ b/clients/python/src/taskbroker_client/registry.py @@ -90,6 +90,7 @@ def register( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ) -> Callable[[Callable[P, R]], Task[P, R]]: """ Register a task. @@ -121,6 +122,9 @@ def register( Enable reporting of ProcessingDeadlineExceededError to Sentry. silenced_exceptions: tuple[type[BaseException], ...] | None A tuple of exception types that will not be reported by Sentry. + pass_headers: bool + If True, the task function will receive task activation headers + as a keyword argument named `headers` (dict[str, str]). """ def wrapped(func: Callable[P, R]) -> Task[P, R]: @@ -141,6 +145,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]: compression_type=compression_type, report_timeout_errors=report_timeout_errors, silenced_exceptions=silenced_exceptions, + pass_headers=pass_headers, ) # TODO(taskworker) tasks should be registered into the registry # so that we can ensure task names are globally unique @@ -234,6 +239,7 @@ def register( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ) -> Callable[[Callable[P, R]], ExternalTask[P, R]]: """ Register an external task stub. @@ -285,6 +291,7 @@ def wrapped(func: Callable[P, R]) -> ExternalTask[P, R]: compression_type=compression_type, report_timeout_errors=report_timeout_errors, silenced_exceptions=silenced_exceptions, + pass_headers=pass_headers, ) self._registered_tasks[name] = task return task diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index f3ec77a2..3fb633d9 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -2,11 +2,12 @@ import base64 import datetime +import inspect import os import time from collections.abc import Callable, Collection, Mapping, MutableMapping from functools import update_wrapper -from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, get_origin from uuid import uuid4 import msgpack @@ -54,6 +55,42 @@ def _get_parameters_format() -> ParametersFormat: R = TypeVar("R") +def assert_typed_kwarg( + func: Callable[..., Any], + param_name: str, + expected_types: tuple[type, ...], + context: str, +) -> None: + """ + Validate that a function has a keyword argument with a compatible type annotation. + + Raises TypeError if: + - The parameter does not exist + - The parameter is positional-only + - The parameter has a type annotation that is not in expected_types + """ + sig = inspect.signature(func, eval_str=True) + if param_name not in sig.parameters: + raise TypeError(f"{context}: function does not have a {param_name!r} parameter") + + param = sig.parameters[param_name] + if param.kind == inspect.Parameter.POSITIONAL_ONLY: + raise TypeError( + f"{context}: {param_name!r} parameter is positional-only. " + f"It must be a keyword argument." + ) + + if param.annotation is not inspect.Parameter.empty: + origin = get_origin(param.annotation) + if origin is None: + origin = param.annotation + if origin not in expected_types: + raise TypeError( + f"{context}: {param_name!r} parameter has type {param.annotation!r}. " + f"Expected one of: {', '.join(t.__name__ for t in expected_types)}." + ) + + class Task(Generic[P, R]): def __init__( self, @@ -68,6 +105,7 @@ def __init__( compression_type: CompressionType = CompressionType.PLAINTEXT, report_timeout_errors: bool = True, silenced_exceptions: tuple[type[BaseException], ...] | None = None, + pass_headers: bool = False, ): self.name = name self._func = func @@ -88,6 +126,16 @@ def __init__( self.compression_type = compression_type self.report_timeout_errors = report_timeout_errors self.silenced_exceptions = silenced_exceptions or () + self.pass_headers = pass_headers + + if pass_headers: + assert_typed_kwarg( + func, + "headers", + (dict, Mapping, MutableMapping, Any), + f"Task {name!r} with pass_headers=True", + ) + update_wrapper(self, func) @property @@ -154,7 +202,15 @@ def apply_async( def _call_func(self, *args: Any, **kwargs: Any) -> None: # Overridden in ExternalTask - self._func(*args, **kwargs) + if self.pass_headers: + if "headers" in kwargs: + raise TypeError( + f"Task '{self.name}' has pass_headers=True, but 'headers' was passed in kwargs. " + "The 'headers' parameter is injected by the worker and cannot be passed by the caller." + ) + self._func(*args, headers={}, **kwargs) # type: ignore[arg-type] + else: + self._func(*args, **kwargs) def _signal_send(self, task: Task[Any, Any], args: Any, kwargs: Any) -> None: """ diff --git a/clients/python/src/taskbroker_client/types.py b/clients/python/src/taskbroker_client/types.py index c4a3783f..326ddd96 100644 --- a/clients/python/src/taskbroker_client/types.py +++ b/clients/python/src/taskbroker_client/types.py @@ -8,6 +8,9 @@ from arroyo.types import BrokerValue, Topic from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation, TaskActivationStatus +TaskHeaders = dict[str, str] +"""Headers passed to a task function when pass_headers=True is set.""" + class ContextHook(Protocol): """ diff --git a/clients/python/src/taskbroker_client/worker/workerchild.py b/clients/python/src/taskbroker_client/worker/workerchild.py index 756292a5..fa2e8476 100644 --- a/clients/python/src/taskbroker_client/worker/workerchild.py +++ b/clients/python/src/taskbroker_client/worker/workerchild.py @@ -402,7 +402,15 @@ def _execute_activation( ): for hook in context_hooks: stack.enter_context(hook.on_execute(headers)) - task_func(*args, **kwargs) + if task_func.pass_headers: + if "headers" in kwargs: + raise TypeError( + f"Task '{task_func.name}' has pass_headers=True, but 'headers' was passed in kwargs. " + "The 'headers' parameter is injected by the worker and cannot be passed by the caller." + ) + task_func(*args, headers=headers, **kwargs) + else: + task_func(*args, **kwargs) transaction.set_status(SPANSTATUS.OK) except Exception: transaction.set_status(SPANSTATUS.INTERNAL_ERROR) diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index 7fd9785e..2ea584d4 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -465,3 +465,91 @@ def multi_task() -> None: activation = multi_task.create_activation([], {}) assert activation.headers["x-test-context"] == "dispatched" assert activation.headers["x-another"] == "also-here" + + +def test_task_pass_headers_attribute(task_namespace: TaskNamespace) -> None: + """Tasks can opt into receiving headers via pass_headers=True.""" + + @task_namespace.register(name="test.with_headers", pass_headers=True) + def with_headers(org_id: int, headers: dict[str, str]) -> None: + pass + + assert with_headers.pass_headers is True + + @task_namespace.register(name="test.without_headers") + def without_headers(org_id: int) -> None: + pass + + assert without_headers.pass_headers is False + + +def test_pass_headers_requires_headers_parameter(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True must have a 'headers' parameter.""" + with pytest.raises(TypeError, match="does not have a 'headers' parameter"): + + @task_namespace.register(name="test.missing_headers", pass_headers=True) + def missing_headers(org_id: int) -> None: + pass + + +def test_pass_headers_rejects_positional_only_headers(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True cannot have a positional-only 'headers' parameter.""" + with pytest.raises(TypeError, match="positional-only"): + + @task_namespace.register(name="test.positional_headers", pass_headers=True) + def positional_headers(org_id: int, headers: dict[str, str], /) -> None: + pass + + +def test_pass_headers_rejects_incompatible_type_annotation( + task_namespace: TaskNamespace, +) -> None: + """Tasks with pass_headers=True must have a dict-like type annotation for 'headers'.""" + with pytest.raises(TypeError, match="Expected one of: dict"): + + @task_namespace.register(name="test.wrong_type_headers", pass_headers=True) + def wrong_type_headers(org_id: int, headers: str) -> None: + pass + + +def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None: + """In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers.""" + calls: list[dict[str, Any]] = [] + + @task_namespace.register(name="test.headers_task", pass_headers=True) + def headers_task(value: str, headers: dict[str, str]) -> None: + calls.append({"value": value, "headers": headers}) + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + headers_task.delay("test") # type: ignore[call-arg] + + assert len(calls) == 1 + assert calls[0]["value"] == "test" + assert calls[0]["headers"] == {} + + +def test_pass_headers_rejects_headers_in_kwargs(task_namespace: TaskNamespace) -> None: + """Tasks with pass_headers=True reject 'headers' passed in kwargs.""" + + @task_namespace.register(name="test.headers_collision", pass_headers=True) + def headers_task(value: str, headers: dict[str, str]) -> None: + pass + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + with pytest.raises(TypeError, match="cannot be passed by the caller"): + headers_task.delay("test", headers={"x-custom": "value"}) + + +def test_delay_immediate_mode_without_pass_headers(task_namespace: TaskNamespace) -> None: + """In ALWAYS_EAGER mode, tasks without pass_headers do not receive headers kwarg.""" + calls: list[dict[str, Any]] = [] + + @task_namespace.register(name="test.no_headers_task") + def no_headers_task(value: str) -> None: + calls.append({"value": value}) + + with patch("taskbroker_client.task.ALWAYS_EAGER", True): + no_headers_task.delay("test") + + assert len(calls) == 1 + assert calls[0] == {"value": "test"} diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 904e4f18..f74ba490 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -269,6 +269,22 @@ ), ) +TASK_WITH_HEADERS = InflightTaskActivation( + host="localhost:50051", + receive_timestamp=0, + activation=TaskActivation( + id="headers_task_123", + taskname="examples.task_with_headers", + namespace="examples", + parameters_bytes=msgpack.packb({"args": ["test_value"], "kwargs": {}}, use_bin_type=True), + headers={ + "x-custom-header": "custom_value", + "sentry-trace": "trace-id", + }, + processing_deadline_duration=2, + ), +) + class TestTaskWorker(TestCase): def test_fetch_task(self) -> None: @@ -814,6 +830,34 @@ def test_child_process_record_checkin(mock_capture_checkin: mock.Mock) -> None: ) +def test_child_process_pass_headers() -> None: + """Task with pass_headers=True receives headers from the activation.""" + todo: queue.Queue[InflightTaskActivation] = queue.Queue() + processed: queue.Queue[ProcessingResult] = queue.Queue() + shutdown = Event() + + todo.put(TASK_WITH_HEADERS) + child_process( + "examples.app:app", + todo, + processed, + shutdown, + max_task_count=1, + processing_pool_name="test", + process_type="fork", + ) + + assert todo.empty() + result = processed.get() + assert result.task_id == TASK_WITH_HEADERS.activation.id + assert result.status == TASK_ACTIVATION_STATUS_COMPLETE + + redis = StrictRedis(host="localhost", port=6379, decode_responses=True) + assert redis.get("task-headers-value") == "test_value" + assert redis.get("task-headers-custom") == "custom_value" + redis.delete("task-headers-value", "task-headers-count", "task-headers-custom") + + @mock.patch("taskbroker_client.worker.workerchild.sentry_sdk.capture_exception") def test_child_process_terminate_task(mock_capture: mock.Mock) -> None: todo: queue.Queue[InflightTaskActivation] = queue.Queue()