Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,12 @@ def _latest_version_select(
if load_bundle_model:
query = query.options(joinedload(cls.bundle))

query = query.order_by(cls.created_at.desc()).limit(1)
# Order by version_number, not created_at: version_number is monotonic and unique per
# dag_id, so it is deterministic even when two versions share a created_at timestamp.
# write_dag relies on this select to compute the next version_number; ordering by
# created_at could pick a non-max row under a tie and collide with the
# (dag_id, version_number) unique constraint.
query = query.order_by(cls.version_number.desc()).limit(1)
return query

@classmethod
Expand Down Expand Up @@ -224,7 +229,7 @@ def get_version(
if version_number:
version_select_obj = version_select_obj.where(cls.version_number == version_number)

return session.scalar(version_select_obj.order_by(cls.id.desc()).limit(1))
return session.scalar(version_select_obj.order_by(cls.version_number.desc()).limit(1))

@property
def version(self) -> str:
Expand Down
14 changes: 8 additions & 6 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,15 +544,18 @@ def _prefetch_dag_write_metadata(
if not dag_id_list:
return {}

# Fetch latest serialized_dag (last_updated, dag_hash) per dag_id
# using a window function to pick the most recent row.
# Fetch the serialized_dag (last_updated, dag_hash) of the latest DagVersion per dag_id,
# ordering by version_number so it stays consistent with the DagVersion picked by dv_subq.
sd_subq = (
select(
cls.dag_id.label("dag_id"),
cls.last_updated.label("last_updated"),
cls.dag_hash.label("dag_hash"),
func.row_number().over(partition_by=cls.dag_id, order_by=cls.created_at.desc()).label("rn"),
func.row_number()
.over(partition_by=cls.dag_id, order_by=DagVersion.version_number.desc())
.label("rn"),
)
.join(DagVersion, cls.dag_version_id == DagVersion.id)
.where(cls.dag_id.in_(dag_id_list))
.subquery()
)
Expand All @@ -563,14 +566,13 @@ def _prefetch_dag_write_metadata(
row.dag_id: (row.last_updated, row.dag_hash) for row in sd_rows
}

# Fetch latest DagVersion per dag_id using a window function,
# matching the original write_dag ordering (ORDER BY created_at DESC).
# Fetch latest DagVersion per dag_id, ordering by version_number to match write_dag.
dv_subq = (
select(
DagVersion.id.label("id"),
DagVersion.dag_id.label("dag_id"),
func.row_number()
.over(partition_by=DagVersion.dag_id, order_by=DagVersion.created_at.desc())
.over(partition_by=DagVersion.dag_id, order_by=DagVersion.version_number.desc())
.label("rn"),
)
.where(DagVersion.dag_id.in_(dag_id_list))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from airflow._shared.timezones import timezone
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DBDagBag
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -643,15 +644,17 @@ def test_get_grid_runs(self, session, test_client):
assert _strip_dag_version_ids(response.json()) == [GRID_RUN_1, GRID_RUN_2]

def test_get_grid_runs_multiple_dag_versions(self, session, test_client):
latest_dag_version = session.scalar(select(DagModel).where(DagModel.dag_id == DAG_ID_5)).dag_versions[
-1
]
latest_task_instance = session.scalar(
# run_5_2 is created after version 2 exists, so its task instances run on version 2.
# Reassign one of them to version 1 so the run spans two versions.
first_dag_version = session.scalar(
select(DagVersion).where(DagVersion.dag_id == DAG_ID_5, DagVersion.version_number == 1)
)
task_instance = session.scalar(
select(TaskInstance)
.where(TaskInstance.dag_id == DAG_ID_5, TaskInstance.run_id == "run_5_2")
.limit(1)
)
latest_task_instance.dag_version = latest_dag_version
task_instance.dag_version = first_dag_version
session.commit()

response = test_client.get(f"/grid/runs/{DAG_ID_5}?limit=5")
Expand Down
61 changes: 60 additions & 1 deletion airflow-core/tests/unit/models/test_dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,32 @@
# under the License.
from __future__ import annotations

from datetime import timedelta

import pytest
from sqlalchemy import func, select

from airflow._shared.timezones import timezone
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbundle import DagBundleModel
from airflow.providers.standard.operators.empty import EmptyOperator

from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.db import clear_db_dags
from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_dags

pytestmark = pytest.mark.db_test


class TestDagVersion:
def setup_method(self):
clear_db_dags()
clear_db_dag_bundles()

def teardown_method(self):
# clear_db_dags() first: DagModel.bundle_name has an FK to dag_bundle.
clear_db_dags()
clear_db_dag_bundles()

@pytest.mark.need_serialized_dag
def test_writing_dag_version(self, dag_maker, session):
Expand All @@ -59,6 +67,57 @@ def test_writing_dag_version_with_changes(self, dag_maker, session):
assert latest_version.version_number == 2
assert session.scalar(select(func.count()).where(DagVersion.dag_id == dag.dag_id)) == 2

@staticmethod
def _seed_two_versions_with_inverted_created_at(session, *, dag_id):
"""Create versions 1 and 2 where version 2 has an *earlier* created_at than version 1.

This makes created_at ordering disagree with version_number ordering, modelling the
timestamp tie / clock-skew case the ordering must be robust to. Returns the bundle name.
"""
bundle_name = f"bundle-{dag_id}"
session.add(DagBundleModel(name=bundle_name))
session.flush()
session.add(DagModel(dag_id=dag_id, bundle_name=bundle_name))
session.flush()

base = timezone.utcnow()
for version_number, created_at in ((1, base), (2, base - timedelta(minutes=1))):
session.add(
DagVersion(
dag_id=dag_id,
version_number=version_number,
bundle_name=bundle_name,
created_at=created_at,
last_updated=created_at,
)
)
session.commit()
return bundle_name

def test_latest_version_uses_version_number_not_created_at(self, session):
"""The latest version is the one with the highest version_number, not the latest created_at."""
dag_id = "test_latest_ordering"
self._seed_two_versions_with_inverted_created_at(session, dag_id=dag_id)

assert DagVersion.get_latest_version(dag_id, session=session).version_number == 2
assert DagVersion.get_version(dag_id, session=session).version_number == 2

def test_write_dag_increments_from_max_version_number(self, session):
"""write_dag must increment from the max version_number, not the latest-created row.

Otherwise, when created_at ordering disagrees with version_number ordering, it would
recompute an already-used version_number and violate the (dag_id, version_number) unique
constraint.
"""
dag_id = "test_write_dag_increment"
bundle_name = self._seed_two_versions_with_inverted_created_at(session, dag_id=dag_id)

new_version = DagVersion.write_dag(dag_id=dag_id, bundle_name=bundle_name, session=session)
session.commit()

assert new_version.version_number == 3
assert session.scalar(select(func.count()).where(DagVersion.dag_id == dag_id)) == 3

@pytest.mark.need_serialized_dag
def test_get_version(self, dag_maker, session):
"""The two dags have the same version name and number but different dag ids"""
Expand Down
Loading