diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py index ff42238806b12..45c8f73b75574 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -566,7 +566,6 @@ def wait_dag_run_until_finished( run_id=dag_run_id, interval=interval, result_task_ids=result_task_ids, - session=session, ) return StreamingResponse(waiter.wait()) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py index 110f34c780ead..e7d7cb98c939f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py @@ -26,7 +26,6 @@ import attrs from sqlalchemy import select -from airflow.api_fastapi.common.db.common import SessionDep from airflow.models.dagrun import DagRun from airflow.models.xcom import XCOM_RETURN_KEY, XComModel from airflow.utils.session import create_session_async @@ -35,8 +34,6 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Iterator - from sqlalchemy import ScalarResult - @attrs.define class DagRunWaiter: @@ -46,22 +43,22 @@ class DagRunWaiter: run_id: str interval: float result_task_ids: list[str] | None - session: SessionDep async def _get_dag_run(self) -> DagRun: async with create_session_async() as session: return await session.scalar(select(DagRun).filter_by(dag_id=self.dag_id, run_id=self.run_id)) - def _serialize_xcoms(self) -> dict[str, Any]: + async def _serialize_xcoms(self) -> dict[str, Any]: xcom_query = XComModel.get_many( run_id=self.run_id, key=XCOM_RETURN_KEY, task_ids=self.result_task_ids, dag_ids=self.dag_id, ) - xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars( - xcom_query.order_by(XComModel.task_id, XComModel.map_index) - ) + async with create_session_async() as session: + xcom_results = ( + await session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index)) + ).all() def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any: entries = [row[0] if isinstance(row, tuple) else row for row in g] @@ -74,18 +71,18 @@ def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any: for task_id, g in itertools.groupby(xcom_results, key=operator.attrgetter("task_id")) } - def _serialize_response(self, dag_run: DagRun) -> str: + async def _serialize_response(self, dag_run: DagRun) -> str: resp = {"state": dag_run.state} if dag_run.state not in State.finished_dr_states: return json.dumps(resp) if self.result_task_ids: - resp["results"] = self._serialize_xcoms() + resp["results"] = await self._serialize_xcoms() return json.dumps(resp) async def wait(self) -> AsyncGenerator[str, None]: - yield self._serialize_response(dag_run := await self._get_dag_run()) + yield await self._serialize_response(dag_run := await self._get_dag_run()) yield "\n" while dag_run.state not in State.finished_dr_states: await asyncio.sleep(self.interval) - yield self._serialize_response(dag_run := await self._get_dag_run()) + yield await self._serialize_response(dag_run := await self._get_dag_run()) yield "\n"