Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow-core/docs/migrations-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
| ``1d6611b6ab7c`` (head) | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. |
| ``a4c2d171ae18`` (head) | ``1d6611b6ab7c`` | ``3.3.0`` | Add dag_result to XComModel. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``1d6611b6ab7c`` | ``888b59e02a5b`` | ``3.2.0`` | Add bundle_name to callback table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``888b59e02a5b`` | ``6222ce48e289`` | ``3.2.0`` | Fix migration file ORM inconsistencies. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def set_xcom(
),
] = None,
map_index: Annotated[int, Query()] = -1,
dag_result: Annotated[bool, Query(description="Whether this XCom is a dag result")] = False,
mapped_length: Annotated[
int | None, Query(description="Number of mapped tasks this value expands into")
] = None,
Expand Down Expand Up @@ -397,6 +398,7 @@ def set_xcom(
dag_id=dag_id,
map_index=map_index,
serialize=False,
dag_result=dag_result,
session=session,
)
except ValueError as e:
Expand Down
9 changes: 8 additions & 1 deletion airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,14 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp = xcom
elif isinstance(msg, SetXCom):
self.client.xcoms.set(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
msg.dag_id,
msg.run_id,
msg.task_id,
msg.key,
msg.value,
msg.map_index,
dag_result=msg.dag_result,
mapped_length=msg.mapped_length,
)
elif isinstance(msg, GetDRCount):
dr_count = self.client.dag_runs.get_count(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add dag_result to XComModel.

Revision ID: a4c2d171ae18
Revises: 1d6611b6ab7c
Create Date: 2026-03-17 00:23:45.305588

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "a4c2d171ae18"
down_revision = "1d6611b6ab7c"
branch_labels = None
depends_on = None
airflow_version = "3.3.0"


def upgrade():
"""Add dag_result to XComModel."""
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_result", sa.Boolean, nullable=True))


def downgrade():
"""Remove dag_result from XComModel."""
with op.batch_alter_table("xcom", schema=None) as batch_op:
batch_op.drop_column("dag_result")
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from sqlalchemy import (
JSON,
Boolean,
ForeignKeyConstraint,
Index,
Integer,
Expand Down Expand Up @@ -66,6 +67,7 @@ class XComModel(TaskInstanceDependencies):
task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1")
key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
dag_result: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)

# Denormalized for easier lookup.
dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
Expand Down Expand Up @@ -163,6 +165,7 @@ def set(
run_id: str,
map_index: int = -1,
serialize: bool = True,
dag_result: bool = False,
session: Session = NEW_SESSION,
) -> None:
"""
Expand Down Expand Up @@ -241,6 +244,7 @@ def set(
task_id=task_id,
dag_id=dag_id,
map_index=map_index,
dag_result=dag_result,
)
session.add(new)
session.flush()
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class MappedClassProtocol(Protocol):
"3.1.0": "cc92b33c6709",
"3.1.8": "509b94a1042d",
"3.2.0": "1d6611b6ab7c",
"3.3.0": "a4c2d171ae18",
}

# Prefix used to identify tables holding data moved during migration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,26 @@ def test_xcom_roundtrip(self, client, create_task_instance, session, value, expe
assert response.status_code == 200
assert XComResponse.model_validate_json(response.read()).value == expected_value

def test_xcom_dag_result(self, client, create_task_instance, session):
"""
Test that the dag_result flag propagates to XComModel.
"""
ti = create_task_instance()
client.post(
f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/return_value",
params={"dag_result": True},
json=123,
)

dag_result = session.scalar(
select(XComModel.dag_result).where(
XComModel.task_id == ti.task_id,
XComModel.dag_id == ti.dag_id,
XComModel.key == "return_value",
)
)
assert dag_result is True


class TestXComsDeleteEndpoint:
def test_xcom_delete_endpoint(self, client, create_task_instance, session):
Expand Down
11 changes: 8 additions & 3 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,18 @@ def set(
key: str,
value,
map_index: int | None = None,
*,
dag_result: bool = False,
mapped_length: int | None = None,
) -> OKResponse:
"""Set a XCom value via the API server."""
# TODO: check if we need to use map_index as params in the uri
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
params = {}
params: dict[str, Any] = {}
if dag_result:
params["dag_result"] = dag_result
if map_index is not None and map_index >= 0:
params = {"map_index": map_index}
params["map_index"] = map_index
if mapped_length is not None and mapped_length >= 0:
params["mapped_length"] = mapped_length
self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
Expand All @@ -554,9 +558,10 @@ def delete(
map_index: int | None = None,
) -> OKResponse:
"""Delete a XCom with given key via the API server."""
params = {}
if map_index is not None and map_index >= 0:
params = {"map_index": map_index}
else:
params = {}
self.client.delete(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
# Any error from the server will anyway be propagated down to the supervisor,
# so we choose to send a generic response to the supervisor over the server response to
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def set(
task_id: str,
run_id: str,
map_index: int = -1,
dag_result: bool = False,
_mapped_length: int | None = None,
) -> None:
"""
Expand Down Expand Up @@ -91,6 +92,7 @@ def set(
task_id=task_id,
run_id=run_id,
map_index=map_index,
dag_result=dag_result,
mapped_length=_mapped_length,
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class AbstractOperator(Templater, DAGNode):
_on_failure_fail_dagrun = False
is_setup: bool = False
is_teardown: bool = False
returns_dag_result: bool = False

HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
(
Expand Down
16 changes: 14 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from dateutil.relativedelta import relativedelta

from airflow import settings
from airflow.sdk import TaskInstanceState, TriggerRule
from airflow.sdk import TaskInstanceState, TriggerRule, XComArg
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.timetable import BaseTimetable
from airflow.sdk.definitions._internal.node import validate_key
Expand All @@ -62,7 +62,7 @@

if TYPE_CHECKING:
from re import Pattern
from typing import TypeAlias
from typing import TypeAlias, TypeVar

from pendulum.tz.timezone import FixedTimezone, Timezone
from typing_extensions import Self, TypeIs
Expand All @@ -78,6 +78,8 @@

Operator: TypeAlias = BaseOperator | MappedOperator

X = TypeVar("X", bound=XComArg)

log = logging.getLogger(__name__)

TAG_MAX_LEN = 100
Expand Down Expand Up @@ -1098,6 +1100,16 @@ def _remove_task(self, task_id: str) -> None:
if tg:
tg._remove(task)

def add_result(self, xcom_arg: X) -> X:
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions.xcom_arg import PlainXComArg

if not isinstance(xcom_arg, PlainXComArg) or xcom_arg.key != BaseXCom.XCOM_RETURN_KEY:
raise ValueError("Only plain return value can be used as dag result")

xcom_arg.operator.returns_dag_result = True
return xcom_arg

def check_cycle(self) -> None:
"""
Check to see if there are any cycles in the Dag.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

from airflow.sdk.bases.decorator import TaskDecorator
from airflow.sdk.definitions.dag import dag
Expand All @@ -25,6 +25,9 @@
from airflow.sdk.definitions.decorators.task_group import task_group
from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime

if TYPE_CHECKING:
from collections.abc import Callable

# Please keep this in sync with the .pyi's __all__.
__all__ = [
"TaskDecorator",
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ class SetXCom(BaseModel):
run_id: str
task_id: str
map_index: int | None = None
dag_result: bool = False
mapped_length: int | None = None
type: Literal["SetXCom"] = "SetXCom"

Expand Down
9 changes: 8 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
self.client.task_instances.skip_downstream_tasks(self.id, msg)
elif isinstance(msg, SetXCom):
self.client.xcoms.set(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
msg.dag_id,
msg.run_id,
msg.task_id,
msg.key,
msg.value,
msg.map_index,
dag_result=msg.dag_result,
mapped_length=msg.mapped_length,
)
elif isinstance(msg, DeleteXCom):
self.client.xcoms.delete(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
Expand Down
9 changes: 8 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,13 @@ def mark_success_url(self) -> str:
return self.log_url


def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
def _xcom_push(
ti: RuntimeTaskInstance,
key: str,
value: Any,
*,
mapped_length: int | None = None,
) -> None:
"""Push a XCom through XCom.set, which pushes to XCom Backend if configured."""
# Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK
# consumers
Expand All @@ -715,6 +721,7 @@ def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
dag_result=ti.task.returns_dag_result,
_mapped_length=mapped_length,
)

Expand Down
30 changes: 27 additions & 3 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,8 +1579,8 @@ class RequestTestCase:
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
None,
None,
),
kwargs={"dag_result": False, "mapped_length": None},
response=OKResponse(ok=True),
),
test_id="set_xcom",
Expand All @@ -1603,8 +1603,8 @@ class RequestTestCase:
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
2,
None,
),
kwargs={"dag_result": False, "mapped_length": None},
response=OKResponse(ok=True),
),
test_id="set_xcom_with_map_index",
Expand All @@ -1628,12 +1628,36 @@ class RequestTestCase:
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
2,
3,
),
kwargs={"dag_result": False, "mapped_length": 3},
response=OKResponse(ok=True),
),
test_id="set_xcom_with_map_index_and_mapped_length",
),
RequestTestCase(
message=SetXCom(
dag_id="test_dag",
run_id="test_run",
task_id="test_task",
key="test_key",
value='{"key": "test_key", "value": {"key2": "value2"}}',
dag_result=True,
),
client_mock=ClientMock(
method_path="xcoms.set",
args=(
"test_dag",
"test_run",
"test_task",
"test_key",
'{"key": "test_key", "value": {"key2": "value2"}}',
None,
),
kwargs={"dag_result": True, "mapped_length": None},
response=OKResponse(ok=True),
),
test_id="set_xcom_with_dag_result",
),
RequestTestCase(
message=DeleteXCom(
dag_id="test_dag",
Expand Down
Loading
Loading