-
-
Notifications
You must be signed in to change notification settings - Fork 6
feat(taskbroker-client): Add pass_headers option to task registration #623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
15615c2
6a018a5
4408594
a99382e
6e002f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,11 +2,12 @@ | |
|
|
||
| import base64 | ||
| import datetime | ||
| import inspect | ||
| import os | ||
| import time | ||
| from collections.abc import Callable, Collection, Mapping, MutableMapping | ||
| from functools import update_wrapper | ||
| from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar | ||
| from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, get_origin | ||
| from uuid import uuid4 | ||
|
|
||
| import msgpack | ||
|
|
@@ -54,6 +55,42 @@ def _get_parameters_format() -> ParametersFormat: | |
| R = TypeVar("R") | ||
|
|
||
|
|
||
| def assert_typed_kwarg( | ||
| func: Callable[..., Any], | ||
| param_name: str, | ||
| expected_types: tuple[type, ...], | ||
| context: str, | ||
| ) -> None: | ||
| """ | ||
| Validate that a function has a keyword argument with a compatible type annotation. | ||
|
|
||
| Raises TypeError if: | ||
| - The parameter does not exist | ||
| - The parameter is positional-only | ||
| - The parameter has a type annotation that is not in expected_types | ||
| """ | ||
| sig = inspect.signature(func, eval_str=True) | ||
| if param_name not in sig.parameters: | ||
| raise TypeError(f"{context}: function does not have a {param_name!r} parameter") | ||
|
|
||
| param = sig.parameters[param_name] | ||
| if param.kind == inspect.Parameter.POSITIONAL_ONLY: | ||
| raise TypeError( | ||
| f"{context}: {param_name!r} parameter is positional-only. " | ||
| f"It must be a keyword argument." | ||
| ) | ||
|
|
||
| if param.annotation is not inspect.Parameter.empty: | ||
| origin = get_origin(param.annotation) | ||
| if origin is None: | ||
| origin = param.annotation | ||
| if origin not in expected_types: | ||
| raise TypeError( | ||
| f"{context}: {param_name!r} parameter has type {param.annotation!r}. " | ||
| f"Expected one of: {', '.join(t.__name__ for t in expected_types)}." | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| ) | ||
|
cursor[bot] marked this conversation as resolved.
untitaker marked this conversation as resolved.
|
||
|
|
||
|
|
||
| class Task(Generic[P, R]): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -68,6 +105,7 @@ def __init__( | |
| compression_type: CompressionType = CompressionType.PLAINTEXT, | ||
| report_timeout_errors: bool = True, | ||
| silenced_exceptions: tuple[type[BaseException], ...] | None = None, | ||
| pass_headers: bool = False, | ||
| ): | ||
| self.name = name | ||
| self._func = func | ||
|
|
@@ -88,6 +126,16 @@ def __init__( | |
| self.compression_type = compression_type | ||
| self.report_timeout_errors = report_timeout_errors | ||
| self.silenced_exceptions = silenced_exceptions or () | ||
| self.pass_headers = pass_headers | ||
|
|
||
| if pass_headers: | ||
| assert_typed_kwarg( | ||
| func, | ||
| "headers", | ||
| (dict, Mapping, MutableMapping, Any), | ||
| f"Task {name!r} with pass_headers=True", | ||
| ) | ||
|
|
||
| update_wrapper(self, func) | ||
|
|
||
| @property | ||
|
|
@@ -154,7 +202,15 @@ def apply_async( | |
|
|
||
| def _call_func(self, *args: Any, **kwargs: Any) -> None: | ||
| # Overridden in ExternalTask | ||
| self._func(*args, **kwargs) | ||
| if self.pass_headers: | ||
| if "headers" in kwargs: | ||
| raise TypeError( | ||
| f"Task '{self.name}' has pass_headers=True, but 'headers' was passed in kwargs. " | ||
| "The 'headers' parameter is injected by the worker and cannot be passed by the caller." | ||
| ) | ||
| self._func(*args, headers={}, **kwargs) # type: ignore[arg-type] | ||
|
untitaker marked this conversation as resolved.
|
||
| else: | ||
| self._func(*args, **kwargs) | ||
|
|
||
| def _signal_send(self, task: Task[Any, Any], args: Any, kwargs: Any) -> None: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -465,3 +465,91 @@ def multi_task() -> None: | |
| activation = multi_task.create_activation([], {}) | ||
| assert activation.headers["x-test-context"] == "dispatched" | ||
| assert activation.headers["x-another"] == "also-here" | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test for the case when the user passes
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. testing for a specific error doesn't make a lot of sense if there's no explicit errorhandling for it, so I added more checks to the code too. i think it's a bit bloated now, but probably more user-friendly. |
||
|
|
||
| def test_task_pass_headers_attribute(task_namespace: TaskNamespace) -> None: | ||
| """Tasks can opt into receiving headers via pass_headers=True.""" | ||
|
|
||
| @task_namespace.register(name="test.with_headers", pass_headers=True) | ||
| def with_headers(org_id: int, headers: dict[str, str]) -> None: | ||
| pass | ||
|
|
||
| assert with_headers.pass_headers is True | ||
|
|
||
| @task_namespace.register(name="test.without_headers") | ||
| def without_headers(org_id: int) -> None: | ||
| pass | ||
|
|
||
| assert without_headers.pass_headers is False | ||
|
|
||
|
|
||
| def test_pass_headers_requires_headers_parameter(task_namespace: TaskNamespace) -> None: | ||
| """Tasks with pass_headers=True must have a 'headers' parameter.""" | ||
| with pytest.raises(TypeError, match="does not have a 'headers' parameter"): | ||
|
|
||
| @task_namespace.register(name="test.missing_headers", pass_headers=True) | ||
| def missing_headers(org_id: int) -> None: | ||
| pass | ||
|
|
||
|
|
||
| def test_pass_headers_rejects_positional_only_headers(task_namespace: TaskNamespace) -> None: | ||
| """Tasks with pass_headers=True cannot have a positional-only 'headers' parameter.""" | ||
| with pytest.raises(TypeError, match="positional-only"): | ||
|
|
||
| @task_namespace.register(name="test.positional_headers", pass_headers=True) | ||
| def positional_headers(org_id: int, headers: dict[str, str], /) -> None: | ||
| pass | ||
|
|
||
|
|
||
| def test_pass_headers_rejects_incompatible_type_annotation( | ||
| task_namespace: TaskNamespace, | ||
| ) -> None: | ||
| """Tasks with pass_headers=True must have a dict-like type annotation for 'headers'.""" | ||
| with pytest.raises(TypeError, match="Expected one of: dict"): | ||
|
|
||
| @task_namespace.register(name="test.wrong_type_headers", pass_headers=True) | ||
| def wrong_type_headers(org_id: int, headers: str) -> None: | ||
| pass | ||
|
|
||
|
|
||
| def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None: | ||
| """In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers.""" | ||
| calls: list[dict[str, Any]] = [] | ||
|
|
||
| @task_namespace.register(name="test.headers_task", pass_headers=True) | ||
| def headers_task(value: str, headers: dict[str, str]) -> None: | ||
| calls.append({"value": value, "headers": headers}) | ||
|
|
||
| with patch("taskbroker_client.task.ALWAYS_EAGER", True): | ||
| headers_task.delay("test") # type: ignore[call-arg] | ||
|
|
||
| assert len(calls) == 1 | ||
| assert calls[0]["value"] == "test" | ||
| assert calls[0]["headers"] == {} | ||
|
|
||
|
|
||
| def test_pass_headers_rejects_headers_in_kwargs(task_namespace: TaskNamespace) -> None: | ||
| """Tasks with pass_headers=True reject 'headers' passed in kwargs.""" | ||
|
|
||
| @task_namespace.register(name="test.headers_collision", pass_headers=True) | ||
| def headers_task(value: str, headers: dict[str, str]) -> None: | ||
| pass | ||
|
|
||
| with patch("taskbroker_client.task.ALWAYS_EAGER", True): | ||
| with pytest.raises(TypeError, match="cannot be passed by the caller"): | ||
| headers_task.delay("test", headers={"x-custom": "value"}) | ||
|
|
||
|
|
||
| def test_delay_immediate_mode_without_pass_headers(task_namespace: TaskNamespace) -> None: | ||
| """In ALWAYS_EAGER mode, tasks without pass_headers do not receive headers kwarg.""" | ||
| calls: list[dict[str, Any]] = [] | ||
|
|
||
| @task_namespace.register(name="test.no_headers_task") | ||
| def no_headers_task(value: str) -> None: | ||
| calls.append({"value": value}) | ||
|
|
||
| with patch("taskbroker_client.task.ALWAYS_EAGER", True): | ||
| no_headers_task.delay("test") | ||
|
|
||
| assert len(calls) == 1 | ||
| assert calls[0] == {"value": "test"} | ||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're going to inject parameters into the task, should we validate that the task doesn't already have a
headersparameter with an incompatible type?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as in, you have a
pass_activationoption?this feels magical but i can do it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can enforce that the headers kwarg is there explicitly.
def mytask(**kwargs):is then just invalid.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was more thinking of the scenario where a task has an incompatible
headersparameter like:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, but as I think about it more I like this idea less, and
pass_headersis the better solution.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think the current version covers this. it requires the header arg to be there, to have the right type, and to be explicit