diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 07d6f5afc16e9..eecc8056b14ba 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``1d6611b6ab7c`` (head) | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. | +| ``a4c2d171ae18`` (head) | ``1d6611b6ab7c`` | ``3.3.0`` | Add dag_result to XComModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``1d6611b6ab7c`` | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``888b59e02a5b`` | ``6222ce48e289`` | ``3.2.0`` | Fix migration file ORM inconsistencies. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index 9b83c40db5e30..69c0293b50359 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -345,6 +345,7 @@ def set_xcom( ), ] = None, map_index: Annotated[int, Query()] = -1, + dag_result: Annotated[bool, Query(description="Whether this XCom is a dag result")] = False, mapped_length: Annotated[ int | None, Query(description="Number of mapped tasks this value expands into") ] = None, @@ -397,6 +398,7 @@ def set_xcom( dag_id=dag_id, map_index=map_index, serialize=False, + dag_result=dag_result, session=session, ) except ValueError as e: diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 6bf54cdc1daa7..7522349c461d0 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -519,7 +519,14 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r resp = xcom elif isinstance(msg, SetXCom): self.client.xcoms.set( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.value, + msg.map_index, + dag_result=msg.dag_result, + mapped_length=msg.mapped_length, ) elif isinstance(msg, GetDRCount): dr_count = self.client.dag_runs.get_count( diff --git a/airflow-core/src/airflow/migrations/versions/0110_3_3_0_xcom_dag_result.py b/airflow-core/src/airflow/migrations/versions/0110_3_3_0_xcom_dag_result.py new file mode 100644 index 0000000000000..6f852e2c9a189 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0110_3_3_0_xcom_dag_result.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add dag_result to XComModel. + +Revision ID: a4c2d171ae18 +Revises: 1d6611b6ab7c +Create Date: 2026-03-17 00:23:45.305588 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a4c2d171ae18" +down_revision = "1d6611b6ab7c" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """Add dag_result to XComModel.""" + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.add_column(sa.Column("dag_result", sa.Boolean, nullable=True)) + + +def downgrade(): + """Remove dag_result from XComModel.""" + with op.batch_alter_table("xcom", schema=None) as batch_op: + batch_op.drop_column("dag_result") diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index d474d521046c0..60d380ae989ce 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -25,6 +25,7 @@ from sqlalchemy import ( JSON, + Boolean, ForeignKeyConstraint, Index, Integer, @@ -66,6 +67,7 @@ class XComModel(TaskInstanceDependencies): task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1") key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) + dag_result: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False) # Denormalized for easier lookup. dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False) @@ -163,6 +165,7 @@ def set( run_id: str, map_index: int = -1, serialize: bool = True, + dag_result: bool = False, session: Session = NEW_SESSION, ) -> None: """ @@ -241,6 +244,7 @@ def set( task_id=task_id, dag_id=dag_id, map_index=map_index, + dag_result=dag_result, ) session.add(new) session.flush() diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 9bc0608611b5a..a32644df803e6 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,6 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", + "3.3.0": "a4c2d171ae18", } # Prefix used to identify tables holding data moved during migration. diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index 2135cb970a48b..eea40749270ba 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -560,6 +560,26 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe assert response.status_code == 200 assert XComResponse.model_validate_json(response.read()).value == expected_value + def test_xcom_dag_result(self, client, create_task_instance, session): + """ + Test that the dag_result flag propagates to XComModel. + """ + ti = create_task_instance() + client.post( + f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/return_value", + params={"dag_result": True}, + json=123, + ) + + dag_result = session.scalar( + select(XComModel.dag_result).where( + XComModel.task_id == ti.task_id, + XComModel.dag_id == ti.dag_id, + XComModel.key == "return_value", + ) + ) + assert dag_result is True + class TestXComsDeleteEndpoint: def test_xcom_delete_endpoint(self, client, create_task_instance, session): diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 19e691281f73b..d0362ce12aa61 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -529,14 +529,18 @@ def set( key: str, value, map_index: int | None = None, + *, + dag_result: bool = False, mapped_length: int | None = None, ) -> OKResponse: """Set a XCom value via the API server.""" # TODO: check if we need to use map_index as params in the uri # ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81 - params = {} + params: dict[str, Any] = {} + if dag_result: + params["dag_result"] = dag_result if map_index is not None and map_index >= 0: - params = {"map_index": map_index} + params["map_index"] = map_index if mapped_length is not None and mapped_length >= 0: params["mapped_length"] = mapped_length self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value) @@ -554,9 +558,10 @@ def delete( map_index: int | None = None, ) -> OKResponse: """Delete a XCom with given key via the API server.""" - params = {} if map_index is not None and map_index >= 0: params = {"map_index": map_index} + else: + params = {} self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params) # Any error from the server will anyway be propagated down to the supervisor, # so we choose to send a generic response to the supervisor over the server response to diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 81cbfbed478e6..9667a608a1d0b 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -59,6 +59,7 @@ def set( task_id: str, run_id: str, map_index: int = -1, + dag_result: bool = False, _mapped_length: int | None = None, ) -> None: """ @@ -91,6 +92,7 @@ def set( task_id=task_id, run_id=run_id, map_index=map_index, + dag_result=dag_result, mapped_length=_mapped_length, ), ) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 00b811146a688..cc0fb44f4e73e 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -113,6 +113,7 @@ class AbstractOperator(Templater, DAGNode): _on_failure_fail_dagrun = False is_setup: bool = False is_teardown: bool = False + returns_dag_result: bool = False HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( ( diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index c3f786584746b..3df36692513a7 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -40,7 +40,7 @@ from dateutil.relativedelta import relativedelta from airflow import settings -from airflow.sdk import TaskInstanceState, TriggerRule +from airflow.sdk import TaskInstanceState, TriggerRule, XComArg from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions._internal.node import validate_key @@ -62,7 +62,7 @@ if TYPE_CHECKING: from re import Pattern - from typing import TypeAlias + from typing import TypeAlias, TypeVar from pendulum.tz.timezone import FixedTimezone, Timezone from typing_extensions import Self, TypeIs @@ -78,6 +78,8 @@ Operator: TypeAlias = BaseOperator | MappedOperator + X = TypeVar("X", bound=XComArg) + log = logging.getLogger(__name__) TAG_MAX_LEN = 100 @@ -1098,6 +1100,16 @@ def _remove_task(self, task_id: str) -> None: if tg: tg._remove(task) + def add_result(self, xcom_arg: X) -> X: + from airflow.sdk.bases.xcom import BaseXCom + from airflow.sdk.definitions.xcom_arg import PlainXComArg + + if not isinstance(xcom_arg, PlainXComArg) or xcom_arg.key != BaseXCom.XCOM_RETURN_KEY: + raise ValueError("Only plain return value can be used as dag result") + + xcom_arg.operator.returns_dag_result = True + return xcom_arg + def check_cycle(self) -> None: """ Check to see if there are any cycles in the Dag. diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py index 7a4d3125e5fd8..d3d99951609a3 100644 --- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from collections.abc import Callable +from typing import TYPE_CHECKING from airflow.sdk.bases.decorator import TaskDecorator from airflow.sdk.definitions.dag import dag @@ -25,6 +25,9 @@ from airflow.sdk.definitions.decorators.task_group import task_group from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime +if TYPE_CHECKING: + from collections.abc import Callable + # Please keep this in sync with the .pyi's __all__. __all__ = [ "TaskDecorator", diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 2a9a9bbd4eb02..a47554249775d 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -838,6 +838,7 @@ class SetXCom(BaseModel): run_id: str task_id: str map_index: int | None = None + dag_result: bool = False mapped_length: int | None = None type: Literal["SetXCom"] = "SetXCom" diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 1dfefee54047c..a4386927fee0c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1339,7 +1339,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: self.client.task_instances.skip_downstream_tasks(self.id, msg) elif isinstance(msg, SetXCom): self.client.xcoms.set( - msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.value, + msg.map_index, + dag_result=msg.dag_result, + mapped_length=msg.mapped_length, ) elif isinstance(msg, DeleteXCom): self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index b8fa9377b1616..649c076126c65 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -703,7 +703,13 @@ def mark_success_url(self) -> str: return self.log_url -def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: +def _xcom_push( + ti: RuntimeTaskInstance, + key: str, + value: Any, + *, + mapped_length: int | None = None, +) -> None: """Push a XCom through XCom.set, which pushes to XCom Backend if configured.""" # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK # consumers @@ -715,6 +721,7 @@ def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index, + dag_result=ti.task.returns_dag_result, _mapped_length=mapped_length, ) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b486ce7776611..e88b5c794cfda 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1579,8 +1579,8 @@ class RequestTestCase: "test_key", '{"key": "test_key", "value": {"key2": "value2"}}', None, - None, ), + kwargs={"dag_result": False, "mapped_length": None}, response=OKResponse(ok=True), ), test_id="set_xcom", @@ -1603,8 +1603,8 @@ class RequestTestCase: "test_key", '{"key": "test_key", "value": {"key2": "value2"}}', 2, - None, ), + kwargs={"dag_result": False, "mapped_length": None}, response=OKResponse(ok=True), ), test_id="set_xcom_with_map_index", @@ -1628,12 +1628,36 @@ class RequestTestCase: "test_key", '{"key": "test_key", "value": {"key2": "value2"}}', 2, - 3, ), + kwargs={"dag_result": False, "mapped_length": 3}, response=OKResponse(ok=True), ), test_id="set_xcom_with_map_index_and_mapped_length", ), + RequestTestCase( + message=SetXCom( + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + key="test_key", + value='{"key": "test_key", "value": {"key2": "value2"}}', + dag_result=True, + ), + client_mock=ClientMock( + method_path="xcoms.set", + args=( + "test_dag", + "test_run", + "test_task", + "test_key", + '{"key": "test_key", "value": {"key2": "value2"}}', + None, + ), + kwargs={"dag_result": True, "mapped_length": None}, + response=OKResponse(ok=True), + ), + test_id="set_xcom_with_dag_result", + ), RequestTestCase( message=DeleteXCom( dag_id="test_dag", diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index ab3bec4707525..5aeb009bd33da 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -3082,7 +3082,7 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task) with mock.patch.object(XCom, "set") as mock_xcom_set: - _xcom_push(runtime_ti, BaseXCom.XCOM_RETURN_KEY, result, 7) + _xcom_push(runtime_ti, BaseXCom.XCOM_RETURN_KEY, result, mapped_length=7) mock_xcom_set.assert_called_once_with( key=BaseXCom.XCOM_RETURN_KEY, value=result, @@ -3090,6 +3090,7 @@ def execute(self, context): task_id=runtime_ti.task_id, run_id=runtime_ti.run_id, map_index=runtime_ti.map_index, + dag_result=False, _mapped_length=7, ) @@ -3158,6 +3159,7 @@ def execute(self, context): task_id="pull_task", run_id="test_run", map_index=-1, + dag_result=False, _mapped_length=None, ) @@ -4718,3 +4720,24 @@ def test_operator_failures_metrics_emitted(self, create_runtime_ti, mock_supervi tags={**stats_tags, "operator": "PythonOperator"}, ) mock_stats.incr.assert_any_call("ti_failures", tags=stats_tags) + + +def test_dag_add_result(create_runtime_ti, mock_supervisor_comms): + with DAG(dag_id="test_dag_add_result") as dag: + task = PythonOperator(task_id="t", python_callable=lambda: 123) + dag.add_result(task.output) + + ti = create_runtime_ti(task=task) + run(ti, context=ti.get_template_context(), log=mock.MagicMock()) + + mock_supervisor_comms.send.assert_any_call( + SetXCom( + key="return_value", + value=123, + dag_id="test_dag_add_result", + run_id="test_run", + task_id="t", + map_index=-1, + dag_result=True, + ) + )