Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2587,10 +2587,17 @@ def _sync_execution(
self,
execution: FlyteWorkflowExecution,
sync_nodes: bool = False,
_depth: int = 0,
_max_depth: int = 50,
) -> FlyteWorkflowExecution:
"""
Sync a FlyteWorkflowExecution object with its corresponding remote state.
"""
if _depth > _max_depth:
raise FlyteAssertion(
f"Nesting depth {_depth} exceeds _max_depth={_max_depth} for execution "
f"{execution.id}. Refusing to recurse further to avoid RecursionError."
)
# Update closure, and then data, because we don't want the execution to finish between when we get the data,
# and then for the closure to have is_done to be true.
execution._closure = self.client.get_execution(execution.id).closure
Expand Down Expand Up @@ -2642,14 +2649,18 @@ def _sync_execution(
if sync_nodes:
node_execs = {}
for n in underlying_node_executions:
node_execs[n.id.node_id] = self.sync_node_execution(n, node_mapping) # noqa
node_execs[n.id.node_id] = self.sync_node_execution( # noqa
n, node_mapping, _depth=_depth + 1, _max_depth=_max_depth
)
execution._node_executions = node_execs
return self._assign_inputs_and_outputs(execution, execution_data, node_interface)

def sync_node_execution(
self,
execution: FlyteNodeExecution,
node_mapping: typing.Dict[str, FlyteNode],
_depth: int = 0,
_max_depth: int = 50,
) -> FlyteNodeExecution:
"""
Get data backing a node execution. These FlyteNodeExecution objects should've come from Admin with the model
Expand All @@ -2669,6 +2680,11 @@ def sync_node_execution(
The data model is complicated, so ascertaining which of these happened is a bit tricky. That logic is
encapsulated in this function.
"""
if _depth > _max_depth:
raise FlyteAssertion(
f"Nesting depth {_depth} exceeds _max_depth={_max_depth} for node "
f"{execution.id}. Refusing to recurse further to avoid RecursionError."
)
# For single task execution - the metadata spec node id is missing. In these cases, revert to regular node id
node_id = execution.metadata.spec_node_id
# This case supports single-task execution compiled workflows.
Expand Down Expand Up @@ -2712,7 +2728,7 @@ def sync_node_execution(
launched_exec = self.fetch_execution(
project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name
)
self.sync_execution(launched_exec, sync_nodes=True)
self._sync_execution(launched_exec, sync_nodes=True, _depth=_depth + 1, _max_depth=_max_depth)
if launched_exec.is_done:
# The synced underlying execution should've had these populated.
execution._inputs = launched_exec.inputs
Expand Down Expand Up @@ -2743,7 +2759,12 @@ def sync_node_execution(

dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), dynamic_flyte_wf._node_map)
self.sync_node_execution(
FlyteNodeExecution.promote_from_model(cne),
dynamic_flyte_wf._node_map,
_depth=_depth + 1,
_max_depth=_max_depth,
)
for cne in child_node_executions
]
execution._task_executions = [
Expand All @@ -2757,7 +2778,12 @@ def sync_node_execution(
sub_flyte_workflow = execution._node.flyte_entity
sub_node_mapping = {n.id: n for n in sub_flyte_workflow.flyte_nodes}
execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
self.sync_node_execution(
FlyteNodeExecution.promote_from_model(cne),
sub_node_mapping,
_depth=_depth + 1,
_max_depth=_max_depth,
)
for cne in child_node_executions
]
execution._interface = sub_flyte_workflow.interface
Expand All @@ -2778,7 +2804,12 @@ def sync_node_execution(
sub_node_mapping[else_node.id] = else_node

execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
self.sync_node_execution(
FlyteNodeExecution.promote_from_model(cne),
sub_node_mapping,
_depth=_depth + 1,
_max_depth=_max_depth,
)
for cne in child_node_executions
]
else:
Expand Down Expand Up @@ -2851,7 +2882,12 @@ def sync_node_execution(
sub_node_mapping[else_node.id] = else_node

execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
self.sync_node_execution(
FlyteNodeExecution.promote_from_model(cne),
sub_node_mapping,
_depth=_depth + 1,
_max_depth=_max_depth,
)
for cne in child_node_executions
]

Expand Down
115 changes: 115 additions & 0 deletions tests/flytekit/unit/remote/test_recursion_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

import pytest

from flytekit.configuration import Config
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.identifier import NodeExecutionIdentifier, WorkflowExecutionIdentifier
from flytekit.models.node_execution import (
NodeExecutionClosure,
WorkflowNodeMetadata,
)
from flytekit.remote.executions import FlyteNodeExecution, FlyteWorkflowExecution
from flytekit.remote.interface import TypedInterface
from flytekit.remote.remote import FlyteRemote


def _mock_launched_exec():
    """Build a mock of the workflow execution launched by a launch-plan node.

    The mock is spec'd against ``FlyteWorkflowExecution``, carries a mocked
    workflow whose ``interface`` is spec'd against ``TypedInterface``, and
    reports ``is_done = False`` with fixed ``inputs``/``outputs`` values.

    Returns:
        The configured ``MagicMock`` standing in for the launched execution.
    """
    # Named ``launched`` rather than ``exec`` so the builtin is not shadowed.
    launched = MagicMock(spec=FlyteWorkflowExecution)

    workflow = MagicMock()
    workflow.interface = MagicMock(spec=TypedInterface)
    launched._flyte_workflow = workflow

    launched.is_done = False
    launched.inputs = {"a": 1}
    launched.outputs = {"b": 2}
    return launched


def _make_node_execution(node_id: str, execution_name: str, with_workflow_node_metadata: bool = False):
    """Create a ``FlyteNodeExecution`` in project ``p1`` / domain ``d1`` for the tests.

    Args:
        node_id: Used both as the node execution's node id and as the mocked
            metadata's ``spec_node_id``.
        execution_name: Name of the parent workflow execution.
        with_workflow_node_metadata: When true, the closure carries
            ``WorkflowNodeMetadata`` pointing at an execution named
            ``launched_<execution_name>``; otherwise that field is ``None``.

    Returns:
        The constructed ``FlyteNodeExecution``.
    """
    parent_exec_id = WorkflowExecutionIdentifier("p1", "d1", execution_name)

    # Metadata is mocked: a plain node that is neither a parent node nor an array node.
    metadata = MagicMock()
    metadata.spec_node_id = node_id
    metadata.is_parent_node = False
    metadata.is_array = False

    launched_metadata = (
        WorkflowNodeMetadata(execution_id=WorkflowExecutionIdentifier("p1", "d1", f"launched_{execution_name}"))
        if with_workflow_node_metadata
        else None
    )

    return FlyteNodeExecution(
        id=NodeExecutionIdentifier(node_id, parent_exec_id),
        input_uri="s3://bucket/input",
        closure=NodeExecutionClosure(
            phase=0,
            started_at=datetime.now(timezone.utc),
            duration=timedelta(seconds=1),
            workflow_node_metadata=launched_metadata,
        ),
        metadata=metadata,
    )


@pytest.fixture
def remote():
    """Provide a ``FlyteRemote`` for ``p1``/``d1`` whose client is a ``MagicMock``."""
    # The patch is only active while the remote is constructed and configured:
    # returning from inside the ``with`` block exits it before the test body runs.
    with patch("flytekit.clients.friendly.SynchronousFlyteClient"):
        instance = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
        # Install the mock client directly and mark it as initialized.
        instance._client = MagicMock()
        instance._client_initialized = True
        return instance


def test_max_depth_raises_flyte_assertion(remote):
    """Syncing a launch-plan node with ``_max_depth=0`` must raise ``FlyteAssertion``."""
    remote.fetch_execution = MagicMock(return_value=_mock_launched_exec())
    remote.client.get_node_execution_data = MagicMock(return_value=MagicMock(dynamic_workflow=None))

    node_execution = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True)
    node_mapping = {"n1": MagicMock()}

    # _sync_execution is deliberately left unmocked, because the guard under test lives inside it.
    #   1. sync_node_execution is entered at _depth=0, which is not greater than _max_depth=0.
    #   2. The launched launch-plan path calls _sync_execution with _depth=1.
    #   3. _sync_execution sees 1 > 0 and raises FlyteAssertion.
    with pytest.raises(FlyteAssertion, match="Nesting depth"):
        remote.sync_node_execution(node_execution, node_mapping, _max_depth=0)


def test_reasonable_depth_does_not_raise(remote):
    """With default depth arguments the node sync succeeds and forwards an incremented depth."""
    ne = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True)

    launched_exec = _mock_launched_exec()
    remote.fetch_execution = MagicMock(return_value=launched_exec)
    remote.client.get_node_execution_data = MagicMock(return_value=MagicMock(dynamic_workflow=None))
    # The nested sync is mocked out, so only the node-level guard and the arguments
    # handed to the nested call are exercised here.
    remote._sync_execution = MagicMock()

    result = remote.sync_node_execution(ne, {"n1": MagicMock()})

    assert result is ne
    # Without this check the test would still pass if the depth counter were never
    # forwarded, which would make the recursion guard ineffective.
    remote._sync_execution.assert_called_once_with(launched_exec, sync_nodes=True, _depth=1, _max_depth=50)


def test_nested_under_default_limit(remote):
    """A node entered at ``_depth=1`` stays under the limit and forwards ``_depth=2``."""
    ne = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True)

    launched_exec = _mock_launched_exec()
    remote.fetch_execution = MagicMock(return_value=launched_exec)
    remote.client.get_node_execution_data = MagicMock(return_value=MagicMock(dynamic_workflow=None))
    # The nested sync is mocked out so that only the depth bookkeeping is observed.
    remote._sync_execution = MagicMock()

    result = remote.sync_node_execution(ne, {"n1": MagicMock()}, _depth=1, _max_depth=50)

    assert result is ne
    # The explicit starting depth must be incremented, not reset, when recursing.
    remote._sync_execution.assert_called_once_with(launched_exec, sync_nodes=True, _depth=2, _max_depth=50)


def test_sync_execution_depth_guard(remote):
    """``_sync_execution`` must refuse to run once ``_depth`` exceeds ``_max_depth``."""
    wf_exec = MagicMock(spec=FlyteWorkflowExecution)
    wf_exec.id = WorkflowExecutionIdentifier("p1", "d1", "deep_exec")

    # Match on the concrete values so that a message built from the wrong variables is caught too.
    with pytest.raises(FlyteAssertion, match="Nesting depth 51 exceeds _max_depth=50"):
        remote._sync_execution(wf_exec, sync_nodes=True, _depth=51, _max_depth=50)

    # The guard runs before the closure is fetched, so the client must not have been queried.
    remote.client.get_execution.assert_not_called()