# Patch summary (free text ahead of the first `diff --git` header).
#
# flytekit/remote/remote.py
#   * `_sync_execution` and `sync_node_execution` each gain two trailing parameters,
#     `_depth: int = 0` and `_max_depth: int = 50`.
#   * Both methods raise `FlyteAssertion` ("Nesting depth ... exceeds _max_depth=...") as the first
#     statement after the docstring when `_depth > _max_depth`. The comparison is strict, so a call
#     made with `_depth == _max_depth` is still allowed to proceed.
#   * Every recursive call touched by these hunks forwards `_depth=_depth + 1` and
#     `_max_depth=_max_depth`: the per-node loop in `_sync_execution`, the launched-execution call,
#     and the four list comprehensions over `child_node_executions` in `sync_node_execution`.
#   * The launched-execution path now calls `self._sync_execution(...)` where it previously called
#     `self.sync_execution(...)`.
#
# tests/flytekit/unit/remote/test_recursion_guard.py (new file)
#   * `test_max_depth_raises_flyte_assertion`: with `_max_depth=0`, expects the guard to fire on the
#     nested `_sync_execution` call.
#   * `test_reasonable_depth_does_not_raise` / `test_nested_under_default_limit`: replace
#     `_sync_execution` with a `MagicMock` and assert the node execution is returned unchanged.
#   * `test_sync_execution_depth_guard`: calls `_sync_execution` directly with `_depth=51`,
#     `_max_depth=50` and expects `FlyteAssertion`.
#
# NOTE(review): the launched-execution path bypasses `sync_execution` in favour of `_sync_execution`;
#   what the public method does beyond delegating is not visible in this patch - confirm nothing is lost.
# NOTE(review): `_mock_launched_exec` binds a local named `exec`, which shadows the Python builtin;
#   a different local name would avoid that.
# NOTE(review): the two non-raising tests assert only `result is ne`; they never inspect the mocked
#   `_sync_execution`, so they would still pass if `_depth + 1` stopped being forwarded. Consider
#   asserting on the mock's call arguments.
# NOTE(review): `_depth` / `_max_depth` are added as positional-or-keyword parameters on the public
#   `sync_node_execution`; consider making them keyword-only - TODO confirm with maintainers.
# NOTE(review): as seen here the patch body has its line breaks and indentation collapsed; it would
#   need to be regenerated from the branch before it can be applied.
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 041c701ec4..0120b88f73 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2587,10 +2587,17 @@ def _sync_execution( self, execution: FlyteWorkflowExecution, sync_nodes: bool = False, + _depth: int = 0, + _max_depth: int = 50, ) -> FlyteWorkflowExecution: """ Sync a FlyteWorkflowExecution object with its corresponding remote state. """ + if _depth > _max_depth: + raise FlyteAssertion( + f"Nesting depth {_depth} exceeds _max_depth={_max_depth} for execution " + f"{execution.id}. Refusing to recurse further to avoid RecursionError." + ) # Update closure, and then data, because we don't want the execution to finish between when we get the data, # and then for the closure to have is_done to be true. execution._closure = self.client.get_execution(execution.id).closure @@ -2642,7 +2649,9 @@ def _sync_execution( if sync_nodes: node_execs = {} for n in underlying_node_executions: - node_execs[n.id.node_id] = self.sync_node_execution(n, node_mapping) # noqa + node_execs[n.id.node_id] = self.sync_node_execution( # noqa + n, node_mapping, _depth=_depth + 1, _max_depth=_max_depth + ) execution._node_executions = node_execs return self._assign_inputs_and_outputs(execution, execution_data, node_interface) @@ -2650,6 +2659,8 @@ def sync_node_execution( self, execution: FlyteNodeExecution, node_mapping: typing.Dict[str, FlyteNode], + _depth: int = 0, + _max_depth: int = 50, ) -> FlyteNodeExecution: """ Get data backing a node execution. These FlyteNodeExecution objects should've come from Admin with the model @@ -2669,6 +2680,11 @@ def sync_node_execution( The data model is complicated, so ascertaining which of these happened is a bit tricky. That logic is encapsulated in this function. """ + if _depth > _max_depth: + raise FlyteAssertion( + f"Nesting depth {_depth} exceeds _max_depth={_max_depth} for node " + f"{execution.id}. Refusing to recurse further to avoid RecursionError." 
+ ) # For single task execution - the metadata spec node id is missing. In these cases, revert to regular node id node_id = execution.metadata.spec_node_id # This case supports single-task execution compiled workflows. @@ -2712,7 +2728,7 @@ def sync_node_execution( launched_exec = self.fetch_execution( project=launched_exec_id.project, domain=launched_exec_id.domain, name=launched_exec_id.name ) - self.sync_execution(launched_exec, sync_nodes=True) + self._sync_execution(launched_exec, sync_nodes=True, _depth=_depth + 1, _max_depth=_max_depth) if launched_exec.is_done: # The synced underlying execution should've had these populated. execution._inputs = launched_exec.inputs @@ -2743,7 +2759,12 @@ def sync_node_execution( dynamic_flyte_wf = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) execution._underlying_node_executions = [ - self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), dynamic_flyte_wf._node_map) + self.sync_node_execution( + FlyteNodeExecution.promote_from_model(cne), + dynamic_flyte_wf._node_map, + _depth=_depth + 1, + _max_depth=_max_depth, + ) for cne in child_node_executions ] execution._task_executions = [ @@ -2757,7 +2778,12 @@ def sync_node_execution( sub_flyte_workflow = execution._node.flyte_entity sub_node_mapping = {n.id: n for n in sub_flyte_workflow.flyte_nodes} execution._underlying_node_executions = [ - self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) + self.sync_node_execution( + FlyteNodeExecution.promote_from_model(cne), + sub_node_mapping, + _depth=_depth + 1, + _max_depth=_max_depth, + ) for cne in child_node_executions ] execution._interface = sub_flyte_workflow.interface @@ -2778,7 +2804,12 @@ def sync_node_execution( sub_node_mapping[else_node.id] = else_node execution._underlying_node_executions = [ - self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) + self.sync_node_execution( + 
FlyteNodeExecution.promote_from_model(cne), + sub_node_mapping, + _depth=_depth + 1, + _max_depth=_max_depth, + ) for cne in child_node_executions ] else: @@ -2851,7 +2882,12 @@ def sync_node_execution( sub_node_mapping[else_node.id] = else_node execution._underlying_node_executions = [ - self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping) + self.sync_node_execution( + FlyteNodeExecution.promote_from_model(cne), + sub_node_mapping, + _depth=_depth + 1, + _max_depth=_max_depth, + ) for cne in child_node_executions ] diff --git a/tests/flytekit/unit/remote/test_recursion_guard.py b/tests/flytekit/unit/remote/test_recursion_guard.py new file mode 100644 index 0000000000..2413d38e5b --- /dev/null +++ b/tests/flytekit/unit/remote/test_recursion_guard.py @@ -0,0 +1,115 @@ +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from flytekit.configuration import Config +from flytekit.exceptions.user import FlyteAssertion +from flytekit.models.core.identifier import NodeExecutionIdentifier, WorkflowExecutionIdentifier +from flytekit.models.node_execution import ( + NodeExecutionClosure, + WorkflowNodeMetadata, +) +from flytekit.remote.executions import FlyteNodeExecution, FlyteWorkflowExecution +from flytekit.remote.interface import TypedInterface +from flytekit.remote.remote import FlyteRemote + + +def _mock_launched_exec(): + exec = MagicMock(spec=FlyteWorkflowExecution) + wf = MagicMock() + wf.interface = MagicMock(spec=TypedInterface) + exec._flyte_workflow = wf + exec.is_done = False + exec.inputs = {"a": 1} + exec.outputs = {"b": 2} + return exec + + +def _make_node_execution(node_id: str, execution_name: str, with_workflow_node_metadata: bool = False): + wf_exec_id = WorkflowExecutionIdentifier("p1", "d1", execution_name) + ne_id = NodeExecutionIdentifier(node_id, wf_exec_id) + meta = MagicMock() + meta.is_parent_node = False + meta.is_array = False + meta.spec_node_id = 
node_id + + wf_node_meta = None + if with_workflow_node_metadata: + wf_node_meta = WorkflowNodeMetadata( + execution_id=WorkflowExecutionIdentifier("p1", "d1", f"launched_{execution_name}") + ) + + closure = NodeExecutionClosure( + phase=0, + started_at=datetime.now(timezone.utc), + duration=timedelta(seconds=1), + workflow_node_metadata=wf_node_meta, + ) + return FlyteNodeExecution(id=ne_id, input_uri="s3://bucket/input", closure=closure, metadata=meta) + + +@pytest.fixture +def remote(): + with patch("flytekit.clients.friendly.SynchronousFlyteClient"): + flyte_remote = FlyteRemote( + config=Config.auto(), + default_project="p1", + default_domain="d1", + ) + flyte_remote._client_initialized = True + flyte_remote._client = MagicMock() + return flyte_remote + + +def test_max_depth_raises_flyte_assertion(remote): + ne = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True) + + launched_exec = _mock_launched_exec() + remote.fetch_execution = MagicMock(return_value=launched_exec) + remote.client.get_node_execution_data = MagicMock( + return_value=MagicMock(dynamic_workflow=None) + ) + + # Don't mock _sync_execution — the guard fires inside it when _depth exceeds _max_depth. + # With _max_depth=0, the initial sync_node_execution enters at _depth=0 (passes guard). + # It reaches the launched LP path, which calls _sync_execution with _depth=1. + # _sync_execution then sees 1 > 0 and raises FlyteAssertion. 
+ with pytest.raises(FlyteAssertion, match="Nesting depth"): + remote.sync_node_execution(ne, {"n1": MagicMock()}, _max_depth=0) + + +def test_reasonable_depth_does_not_raise(remote): + ne = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True) + + launched_exec = _mock_launched_exec() + remote.fetch_execution = MagicMock(return_value=launched_exec) + remote.client.get_node_execution_data = MagicMock( + return_value=MagicMock(dynamic_workflow=None) + ) + remote._sync_execution = MagicMock() + + result = remote.sync_node_execution(ne, {"n1": MagicMock()}) + assert result is ne + + +def test_nested_under_default_limit(remote): + ne = _make_node_execution("n1", "exec1", with_workflow_node_metadata=True) + + launched_exec = _mock_launched_exec() + remote.fetch_execution = MagicMock(return_value=launched_exec) + remote.client.get_node_execution_data = MagicMock( + return_value=MagicMock(dynamic_workflow=None) + ) + remote._sync_execution = MagicMock() + + result = remote.sync_node_execution(ne, {"n1": MagicMock()}, _depth=1, _max_depth=50) + assert result is ne + + +def test_sync_execution_depth_guard(remote): + wf_exec = MagicMock(spec=FlyteWorkflowExecution) + wf_exec.id = WorkflowExecutionIdentifier("p1", "d1", "deep_exec") + + with pytest.raises(FlyteAssertion, match="Nesting depth"): + remote._sync_execution(wf_exec, sync_nodes=True, _depth=51, _max_depth=50)