Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,6 @@ def wait_dag_run_until_finished(
run_id=dag_run_id,
interval=interval,
result_task_ids=result_task_ids,
session=session,
)
return StreamingResponse(waiter.wait())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import attrs
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
from airflow.utils.session import create_session_async
Expand All @@ -35,8 +34,6 @@
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterator

from sqlalchemy import ScalarResult


@attrs.define
class DagRunWaiter:
Expand All @@ -46,22 +43,22 @@ class DagRunWaiter:
run_id: str
interval: float
result_task_ids: list[str] | None
session: SessionDep

async def _get_dag_run(self) -> DagRun:
async with create_session_async() as session:
return await session.scalar(select(DagRun).filter_by(dag_id=self.dag_id, run_id=self.run_id))

def _serialize_xcoms(self) -> dict[str, Any]:
async def _serialize_xcoms(self) -> dict[str, Any]:
xcom_query = XComModel.get_many(
run_id=self.run_id,
key=XCOM_RETURN_KEY,
task_ids=self.result_task_ids,
dag_ids=self.dag_id,
)
xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars(
xcom_query.order_by(XComModel.task_id, XComModel.map_index)
)
async with create_session_async() as session:
xcom_results = (
await session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index))
).all()

def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
entries = [row[0] if isinstance(row, tuple) else row for row in g]
Expand All @@ -74,18 +71,18 @@ def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
for task_id, g in itertools.groupby(xcom_results, key=operator.attrgetter("task_id"))
}

def _serialize_response(self, dag_run: DagRun) -> str:
async def _serialize_response(self, dag_run: DagRun) -> str:
resp = {"state": dag_run.state}
if dag_run.state not in State.finished_dr_states:
return json.dumps(resp)
if self.result_task_ids:
resp["results"] = self._serialize_xcoms()
resp["results"] = await self._serialize_xcoms()
return json.dumps(resp)

async def wait(self) -> AsyncGenerator[str, None]:
yield self._serialize_response(dag_run := await self._get_dag_run())
yield await self._serialize_response(dag_run := await self._get_dag_run())
yield "\n"
while dag_run.state not in State.finished_dr_states:
await asyncio.sleep(self.interval)
yield self._serialize_response(dag_run := await self._get_dag_run())
yield await self._serialize_response(dag_run := await self._get_dag_run())
yield "\n"
Loading