From d0e88798395a25a87411bb26aba9f296979545c7 Mon Sep 17 00:00:00 2001 From: joeyutong Date: Thu, 2 Apr 2026 09:34:56 +0800 Subject: [PATCH 1/6] Add Python reconcile-callable durable execution support --- python/flink_agents/api/runner_context.py | 46 +++ .../test_reconcile_fallback_exception.py | 34 ++ .../flink_agents/runtime/durable_execution.py | 74 ++++ .../runtime/flink_runner_context.py | 339 +++++++++++++--- python/flink_agents/runtime/local_runner.py | 16 +- .../runtime/tests/test_durable_execution.py | 56 ++- .../test_flink_runner_context_reconcilable.py | 364 ++++++++++++++++++ .../tests/test_local_runner_reconcilable.py | 72 ++++ .../runtime/context/RunnerContextImpl.java | 18 + 9 files changed, 969 insertions(+), 50 deletions(-) create mode 100644 python/flink_agents/api/tests/test_reconcile_fallback_exception.py create mode 100644 python/flink_agents/runtime/durable_execution.py create mode 100644 python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py create mode 100644 python/flink_agents/runtime/tests/test_local_runner_reconcilable.py diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 7dfd54376..a58a5e403 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -28,6 +28,22 @@ from flink_agents.api.memory_object import MemoryObject +class ReconcileFallbackException(Exception): + """Signal that reconcile could not recover a terminal outcome. + + Raising this exception from a reconcile callable tells the runtime to fall back + to the original durable call instead of treating the reconcile failure as a + terminal outcome. + """ + + def __init__(self, cause: BaseException) -> None: + if not isinstance(cause, BaseException): + err_msg = "ReconcileFallbackException requires a BaseException cause" + raise TypeError(err_msg) + super().__init__(str(cause)) + self.cause = cause + + class AsyncExecutionResult: """This class wraps an asynchronous task that will be submitted to a thread pool only when awaited. This ensures lazy submission and serial execution semantics. @@ -196,6 +212,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -212,6 +229,16 @@ def durable_execute( will always make the durable_execute call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. + If `reconcile` is provided, it is used only when recovery resumes from a + `PENDING` durable call result. The reconcile callable should: + + * return the recovered terminal result when the previous invocation + succeeded + * raise the recovered terminal business exception when the previous + invocation failed + * raise `ReconcileFallbackException` when it cannot determine a terminal + outcome and the runtime should execute `func` + Usage:: def my_action(event, ctx): @@ -224,6 +251,10 @@ def my_action(event, ctx): The function to be executed. *args : Any Positional arguments to pass to the function. + reconcile : Callable[[], Any] | None + Optional zero-argument reconcile callable used only during recovery. + This is a reserved keyword-only parameter and is not forwarded to + `func`. **kwargs : Any Keyword arguments to pass to the function. @@ -238,6 +269,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> "AsyncExecutionResult": """Asynchronously execute the provided function with durable execution support. @@ -251,6 +283,16 @@ def durable_execute_async( will always make the durable_execute_async call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. + If `reconcile` is provided, it is used only when recovery resumes from a + `PENDING` durable call result. The reconcile callable should: + + * return the recovered terminal result when the previous invocation + succeeded + * raise the recovered terminal business exception when the previous + invocation failed + * raise `ReconcileFallbackException` when it cannot determine a terminal + outcome and the runtime should execute `func` + Usage:: async def my_action(event, ctx): @@ -267,6 +309,10 @@ async def my_action(event, ctx): The function to be executed asynchronously. *args : Any Positional arguments to pass to the function. + reconcile : Callable[[], Any] | None + Optional zero-argument reconcile callable used only during recovery. + This is a reserved keyword-only parameter and is not forwarded to + `func`. **kwargs : Any Keyword arguments to pass to the function. diff --git a/python/flink_agents/api/tests/test_reconcile_fallback_exception.py b/python/flink_agents/api/tests/test_reconcile_fallback_exception.py new file mode 100644 index 000000000..3948a56a5 --- /dev/null +++ b/python/flink_agents/api/tests/test_reconcile_fallback_exception.py @@ -0,0 +1,34 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import pytest + +from flink_agents.api.runner_context import ReconcileFallbackException + + +def test_reconcile_fallback_exception_preserves_cause() -> None: + cause = ValueError("indeterminate") + + exc = ReconcileFallbackException(cause) + + assert exc.cause is cause + assert str(exc) == "indeterminate" + + +def test_reconcile_fallback_exception_requires_exception_cause() -> None: + with pytest.raises(TypeError, match="requires a BaseException cause"): + ReconcileFallbackException("not-an-exception") # type: ignore[arg-type] diff --git a/python/flink_agents/runtime/durable_execution.py b/python/flink_agents/runtime/durable_execution.py new file mode 100644 index 000000000..d4ac9c1f3 --- /dev/null +++ b/python/flink_agents/runtime/durable_execution.py @@ -0,0 +1,74 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import hashlib +import inspect +from typing import Any, Callable + +import cloudpickle + + +def _compute_function_id(func: Callable) -> str: + """Compute a stable function identifier from a callable.""" + module_obj = inspect.getmodule(func) + module = ( + module_obj.__name__ + if module_obj is not None + else getattr(func, "__module__", "") + ) + qualname = getattr(func, "__qualname__", getattr(func, "__name__", "")) + return f"{module}.{qualname}" + + +def _compute_args_digest(args: tuple, kwargs: dict) -> str: + """Compute a stable digest of the serialized arguments.""" + try: + serialized = cloudpickle.dumps((args, kwargs)) + return hashlib.sha256(serialized).hexdigest()[:16] + except Exception: + return hashlib.sha256(str((args, kwargs)).encode()).hexdigest()[:16] + + +def _can_bind_call( + func: Callable, + *args: Any, + **kwargs: Any, +) -> bool: + """Return whether the callable signature can bind the provided arguments.""" + try: + inspect.signature(func).bind(*args, **kwargs) + return True + except (TypeError, ValueError): + return False + + +def _validate_reconcile_callable( + reconcile: Callable[[], Any] | None, +) -> Callable[[], Any] | None: + """Validate that the reconcile callable is either absent or zero-argument.""" + if reconcile is None: + return None + + if not callable(reconcile): + err_msg = "reconcile must be callable" + raise TypeError(err_msg) + + if not _can_bind_call(reconcile): + err_msg = "reconcile must be a callable that takes no arguments" + raise TypeError(err_msg) + + return reconcile diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 631267fae..20dcfda3f 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -15,11 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -import hashlib import logging import os from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict import cloudpickle from typing_extensions import override @@ -34,10 +34,18 @@ from flink_agents.api.memory_object import MemoryType from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType -from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext -from flink_agents.plan.agent_plan import AgentPlan +from flink_agents.api.runner_context import ( + AsyncExecutionResult, + ReconcileFallbackException, + RunnerContext, +) from flink_agents.runtime.flink_memory_object import FlinkMemoryObject from flink_agents.runtime.flink_metric_group import FlinkMetricGroup +from flink_agents.runtime.durable_execution import ( + _compute_args_digest, + _compute_function_id, + _validate_reconcile_callable, +) from flink_agents.runtime.memory.internal_base_long_term_memory import ( InternalBaseLongTermMemory, ) @@ -47,9 +55,28 @@ from flink_agents.runtime.python_java_utils import _build_event_log_string from flink_agents.runtime.resource_cache import ResourceCache +if TYPE_CHECKING: + from flink_agents.plan.agent_plan import AgentPlan + logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class _PersistedCallResult: + function_id: str + args_digest: str + status: str + result_payload: bytes | None + exception_payload: bytes | None + + +@dataclass(frozen=True) +class _ReconcileCallResolution: + kind: str + result: Any = None + exception: BaseException | None = None + + class _DurableExecutionResult: """Wrapper that holds result and triggers recording when unwrapped.""" @@ -155,28 +182,80 @@ def __await__(self) -> Any: return result -def _compute_function_id(func: Callable) -> str: - """Compute a stable function identifier from a callable. +class _ReconcileDurableAsyncExecutionResult(AsyncExecutionResult): + """An AsyncExecutionResult that resolves reconcile state on await.""" - Returns module.qualname for functions/methods. - """ - module = getattr(func, "__module__", "") - qualname = getattr(func, "__qualname__", getattr(func, "__name__", "")) - return f"{module}.{qualname}" + def __init__( + self, + ctx: "FlinkRunnerContext", + executor: Any, + func: Callable, + args: tuple, + reconcile: Callable[[], Any], + kwargs: dict, + ) -> None: + super().__init__(executor, func, args, kwargs) + self._ctx = ctx + self._reconcile = reconcile + + def __await__(self) -> Any: + resolution = self._ctx._resolve_reconcile_call( + self._func, + self._args, + self._reconcile, + self._kwargs, + ) + if resolution.kind == "replay": + result = self._ctx._replay_terminal_call(self._func, self._args, self._kwargs) + if False: + yield + return result -def _compute_args_digest(args: tuple, kwargs: dict) -> str: - """Compute a stable digest of the serialized arguments. + if resolution.kind == "finalize_success": + self._ctx._finalize_current_call( + self._func, + self._args, + self._kwargs, + resolution.result, + None, + ) + if False: + yield + return resolution.result + + if resolution.kind == "finalize_failure": + self._ctx._finalize_current_call( + self._func, + self._args, + self._kwargs, + None, + resolution.exception, + ) + raise resolution.exception - The digest is used to validate that the same arguments are passed - during recovery as during the original execution. - """ - try: - serialized = cloudpickle.dumps((args, kwargs)) - return hashlib.sha256(serialized).hexdigest()[:16] - except Exception: - # If serialization fails, return a fallback digest - return hashlib.sha256(str((args, kwargs)).encode()).hexdigest()[:16] + future = self._executor.submit(self._func, *self._args, **self._kwargs) + while not future.done(): + yield + + exception = None + result = None + try: + result = future.result() + except BaseException as e: + exception = e + + self._ctx._finalize_current_call( + self._func, + self._args, + self._kwargs, + result, + exception, + ) + + if exception is not None: + raise exception + return result class FlinkRunnerContext(RunnerContext): @@ -186,7 +265,7 @@ class FlinkRunnerContext(RunnerContext): durable execution support through execute() and execute_async() methods. """ - __agent_plan: AgentPlan | None + __agent_plan: Any __ltm: InternalBaseLongTermMemory = None def __init__( @@ -203,6 +282,8 @@ def __init__( j_runner_context : Any Java runner context used to synchronize data between Python and Java. """ + from flink_agents.plan.agent_plan import AgentPlan + self._j_runner_context = j_runner_context self.__agent_plan = AgentPlan.model_validate_json(agent_plan_json) self.__resource_cache = ResourceCache( @@ -409,11 +490,159 @@ def _record_call_completion( if "recordCallCompletion" not in str(e): logger.warning("Failed to record call completion: %s", e) + @staticmethod + def _serialize_call_payloads( + result: Any, + exception: BaseException | None, + ) -> tuple[bytes | None, bytes | None]: + result_payload = None if exception else cloudpickle.dumps(result) + exception_payload = cloudpickle.dumps(exception) if exception else None + return result_payload, exception_payload + + def _peek_current_call_result(self) -> _PersistedCallResult | None: + current = self._j_runner_context.getCurrentCallResultFields() + if current is None: + return None + + function_id, args_digest, status, result_payload, exception_payload = current + return _PersistedCallResult( + function_id=function_id, + args_digest=args_digest, + status=status, + result_payload=bytes(result_payload) if result_payload is not None else None, + exception_payload=( + bytes(exception_payload) if exception_payload is not None else None + ), + ) + + def _append_pending_call(self, func: Callable, args: tuple, kwargs: dict) -> None: + self._j_runner_context.appendPendingCall( + _compute_function_id(func), + _compute_args_digest(args, kwargs), + ) + + def _finalize_current_call( + self, + func: Callable, + args: tuple, + kwargs: dict, + result: Any, + exception: BaseException | None, + ) -> None: + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + result_payload, exception_payload = self._serialize_call_payloads( + result, + exception, + ) + self._j_runner_context.finalizeCurrentCall( + function_id, + args_digest, + result_payload, + exception_payload, + ) + + def _clear_call_results_from_current_index_and_persist(self) -> None: + self._j_runner_context.clearCallResultsFromCurrentIndexAndPersist() + + def _replay_terminal_call(self, func: Callable, args: tuple, kwargs: dict) -> Any: + is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) + if not is_hit: + err_msg = "Expected a terminal durable call result but replay did not hit" + raise RuntimeError(err_msg) + return cached_result + + def _resolve_reconcile_call( + self, + func: Callable, + args: tuple, + reconcile: Callable[[], Any], + kwargs: dict, + ) -> _ReconcileCallResolution: + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + current = self._peek_current_call_result() + + if current is None: + self._append_pending_call(func, args, kwargs) + return _ReconcileCallResolution("execute") + + if current.function_id != function_id or current.args_digest != args_digest: + self._clear_call_results_from_current_index_and_persist() + self._append_pending_call(func, args, kwargs) + return _ReconcileCallResolution("execute") + + if current.status != "PENDING": + return _ReconcileCallResolution("replay") + + try: + result = reconcile() + except ReconcileFallbackException: + logger.warning( + "Reconcile fell back for Python durable call function_id=%s, args_digest=%s. " + "Falling back to re-execution.", + function_id, + args_digest, + exc_info=True, + ) + return _ReconcileCallResolution("execute") + except Exception as exception: + return _ReconcileCallResolution( + "finalize_failure", + exception=exception, + ) + + return _ReconcileCallResolution("finalize_success", result=result) + + def _execute_current_pending_call( + self, + func: Callable, + args: tuple, + kwargs: dict, + ) -> Any: + exception = None + result = None + try: + result = func(*args, **kwargs) + except BaseException as e: + exception = e + + self._finalize_current_call(func, args, kwargs, result, exception) + + if exception is not None: + raise exception + return result + + def _wrap_completion_only_func( + self, + func: Callable, + args: tuple, + kwargs: dict, + ) -> Callable[..., Any]: + def wrapped_func(*a: Any, **kw: Any) -> Any: + exception = None + result = None + try: + result = func(*a, **kw) + except BaseException as e: + exception = e + + if exception: + raise _DurableExecutionException( + func, args, kwargs, result, exception, self._record_call_completion + ) + return _DurableExecutionResult( + func, args, kwargs, result, self._record_call_completion + ) + + return wrapped_func + @override def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -426,6 +655,31 @@ def durable_execute( The function is executed synchronously in the current thread, blocking the operator until completion. """ + validated_reconcile = _validate_reconcile_callable(reconcile) + + if validated_reconcile is not None: + resolution = self._resolve_reconcile_call( + func, + args, + validated_reconcile, + kwargs, + ) + if resolution.kind == "replay": + return self._replay_terminal_call(func, args, kwargs) + if resolution.kind == "finalize_success": + self._finalize_current_call(func, args, kwargs, resolution.result, None) + return resolution.result + if resolution.kind == "finalize_failure": + self._finalize_current_call( + func, + args, + kwargs, + None, + resolution.exception, + ) + raise resolution.exception + return self._execute_current_pending_call(func, args, kwargs) + # Try to get cached result for recovery is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) if is_hit: @@ -451,6 +705,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function with durable execution support. @@ -464,32 +719,30 @@ def durable_execute_async( is awaited. Fire-and-forget calls (not awaiting the result) will NOT be recorded and cannot be recovered. """ + validated_reconcile = _validate_reconcile_callable(reconcile) + + if validated_reconcile is not None: + return _ReconcileDurableAsyncExecutionResult( + self, + self.executor, + func, + args, + validated_reconcile, + kwargs, + ) + # Try to get cached result for recovery is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) if is_hit: # Return a pre-completed AsyncExecutionResult return _CachedAsyncExecutionResult(cached_result) - # Create a wrapper function that records completion - def wrapped_func(*a: Any, **kw: Any) -> Any: - exception = None - result = None - try: - result = func(*a, **kw) - except BaseException as e: - exception = e - - # Note: This runs in a thread pool, so we need to be careful - # The actual recording will happen when the result is awaited - if exception: - raise _DurableExecutionException( - func, args, kwargs, result, exception, self._record_call_completion - ) - return _DurableExecutionResult( - func, args, kwargs, result, self._record_call_completion - ) - - return _DurableAsyncExecutionResult(self.executor, wrapped_func, args, kwargs) + return _DurableAsyncExecutionResult( + self.executor, + self._wrap_completion_only_func(func, args, kwargs), + args, + kwargs, + ) @property @override diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 75a73efb2..8f7e1ba65 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -20,7 +20,7 @@ import uuid from collections import deque from concurrent.futures import Future -from typing import Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List from typing_extensions import override @@ -31,12 +31,14 @@ from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext -from flink_agents.plan.agent_plan import AgentPlan from flink_agents.plan.configuration import AgentConfiguration from flink_agents.runtime.agent_runner import AgentRunner from flink_agents.runtime.local_memory_object import LocalMemoryObject from flink_agents.runtime.resource_cache import ResourceCache +if TYPE_CHECKING: + from flink_agents.plan.agent_plan import AgentPlan + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) @@ -58,7 +60,7 @@ class LocalRunnerContext(RunnerContext): Name of the action being executed. """ - __agent_plan: AgentPlan | None + __agent_plan: Any __key: Any events: deque[Event] action_name: str @@ -69,7 +71,7 @@ class LocalRunnerContext(RunnerContext): _config: AgentConfiguration def __init__( - self, agent_plan: AgentPlan, key: Any, config: AgentConfiguration + self, agent_plan: "AgentPlan", key: Any, config: AgentConfiguration ) -> None: """Initialize a new context with the given agent and key. @@ -190,6 +192,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function. Access to memory @@ -208,6 +211,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconcile: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function. Access to memory @@ -271,7 +275,7 @@ class LocalRunner(AgentRunner): Internal configration. """ - __agent_plan: AgentPlan + __agent_plan: Any __keyed_contexts: Dict[Any, LocalRunnerContext] __outputs: List[Dict[str, Any]] __config: AgentConfiguration @@ -284,6 +288,8 @@ def __init__(self, agent: Agent, config: AgentConfiguration) -> None: agent : Agent The agent class to convert and run. """ + from flink_agents.plan.agent_plan import AgentPlan + self.__agent_plan = AgentPlan.from_agent(agent, config) self.__keyed_contexts = {} self.__outputs = [] diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index 080ad846f..b09d0b184 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -18,10 +18,12 @@ """Tests for durable execution helper functions.""" import cloudpickle +import pytest -from flink_agents.runtime.flink_runner_context import ( +from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, + _validate_reconcile_callable, ) @@ -48,6 +50,19 @@ def class_method(cls, x: int) -> int: return x * 4 +class ReconcileCallables: + """Helpers for reconcile callable validation tests.""" + + def __init__(self, prefix: str) -> None: + self.prefix = prefix + + def bound_no_arg(self) -> str: + return f"bound:{self.prefix}" + + def requires_arg(self, value: int) -> str: + return f"{self.prefix}:{value}" + + def test_compute_function_id_for_function() -> None: """Test function ID computation for regular functions.""" func_id = _compute_function_id(sample_function) @@ -127,6 +142,44 @@ def test_compute_args_digest_kwargs_vs_args() -> None: assert digest1 != digest2 +def test_validate_reconcile_callable_accepts_none() -> None: + assert _validate_reconcile_callable(None) is None + + +def test_validate_reconcile_callable_accepts_zero_arg_function() -> None: + def reconcile() -> str: + return "ok" + + validated = _validate_reconcile_callable(reconcile) + + assert validated is reconcile + assert validated() == "ok" + + +def test_validate_reconcile_callable_accepts_bound_zero_arg_method() -> None: + callables = ReconcileCallables("client") + bound_method = callables.bound_no_arg + + validated = _validate_reconcile_callable(bound_method) + + assert validated is bound_method + assert validated() == "bound:client" + + +def test_validate_reconcile_callable_requires_callable() -> None: + with pytest.raises(TypeError, match="reconcile must be callable"): + _validate_reconcile_callable(1) # type: ignore[arg-type] + + +def test_validate_reconcile_callable_requires_zero_args() -> None: + callables = ReconcileCallables("client") + + with pytest.raises( + TypeError, match="reconcile must be a callable that takes no arguments" + ): + _validate_reconcile_callable(callables.requires_arg) + + def test_cloudpickle_serialization() -> None: """Test that results can be serialized and deserialized with cloudpickle.""" # Test basic types @@ -216,4 +269,3 @@ def test_cloudpickle_none_exception_message() -> None: assert isinstance(deserialized, RuntimeError) # str() of an exception with None message is "None" assert str(deserialized) == "None" - diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py new file mode 100644 index 000000000..8afaa1320 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -0,0 +1,364 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import asyncio +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +import cloudpickle +import pytest + +from flink_agents.api.runner_context import ReconcileFallbackException +from flink_agents.runtime.durable_execution import ( + _compute_args_digest, + _compute_function_id, +) +from flink_agents.runtime.flink_runner_context import FlinkRunnerContext + + +@dataclass +class _StoredCallResult: + function_id: str + args_digest: str + status: str + result_payload: bytes | None = None + exception_payload: bytes | None = None + + +class _FakeJavaRunnerContext: + def __init__(self) -> None: + self.call_results: list[_StoredCallResult] = [] + self.current_call_index = 0 + self.operations: list[str] = [] + + def getCurrentCallResultFields(self): + self.operations.append("peek") + if self.current_call_index < len(self.call_results): + current = self.call_results[self.current_call_index] + return [ + current.function_id, + current.args_digest, + current.status, + current.result_payload, + current.exception_payload, + ] + return None + + def matchNextOrClearSubsequentCallResult(self, function_id: str, args_digest: str): + self.operations.append("match") + if self.current_call_index < len(self.call_results): + current = self.call_results[self.current_call_index] + if ( + current.function_id == function_id + and current.args_digest == args_digest + ): + self.current_call_index += 1 + return [True, current.result_payload, current.exception_payload] + self.call_results = self.call_results[: self.current_call_index] + return None + + def recordCallCompletion( + self, + function_id: str, + args_digest: str, + result_payload: bytes | None, + exception_payload: bytes | None, + ) -> None: + self.operations.append("record") + status = "FAILED" if exception_payload is not None else "SUCCEEDED" + self.call_results.append( + _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status=status, + result_payload=result_payload, + exception_payload=exception_payload, + ) + ) + self.current_call_index += 1 + + def appendPendingCall(self, function_id: str, args_digest: str) -> None: + self.operations.append("append_pending") + self.call_results.append( + _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status="PENDING", + ) + ) + + def finalizeCurrentCall( + self, + function_id: str, + args_digest: str, + result_payload: bytes | None, + exception_payload: bytes | None, + ) -> None: + self.operations.append("finalize") + current = self.call_results[self.current_call_index] + assert current.status == "PENDING" + assert current.function_id == function_id + assert current.args_digest == args_digest + self.call_results[self.current_call_index] = _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status="FAILED" if exception_payload is not None else "SUCCEEDED", + result_payload=result_payload, + exception_payload=exception_payload, + ) + self.current_call_index += 1 + + def clearCallResultsFromCurrentIndexAndPersist(self) -> None: + self.operations.append("clear") + self.call_results = self.call_results[: self.current_call_index] + + +def _create_runner_context( + j_runner_context: _FakeJavaRunnerContext, +) -> FlinkRunnerContext: + ctx = FlinkRunnerContext.__new__(FlinkRunnerContext) + ctx._j_runner_context = j_runner_context + ctx.executor = ThreadPoolExecutor(max_workers=1) + ctx._FlinkRunnerContext__agent_plan = None + ctx._FlinkRunnerContext__ltm = None + return ctx + + +def _close_runner_context(ctx: FlinkRunnerContext) -> None: + ctx.executor.shutdown(wait=True) + + +def _run_async(result) -> object: + async def _await_result(): + return await result + + return asyncio.run(_await_result()) + + +def _preload_pending(j_runner_context: _FakeJavaRunnerContext, func, *args, **kwargs) -> None: + j_runner_context.call_results.append( + _StoredCallResult( + function_id=_compute_function_id(func), + args_digest=_compute_args_digest(args, kwargs), + status="PENDING", + ) + ) + + +def _call_value(value: str) -> str: + return f"call:{value}" + + +def test_flink_runner_context_sync_with_reconcile_callable_executes_original_call() -> None: + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + reconcile_called = False + + def reconcile() -> str: + nonlocal reconcile_called + reconcile_called = True + return "reconciled:order-1" + + try: + result = ctx.durable_execute(_call_value, "order-1", reconcile=reconcile) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconcile_called is False + assert j_runner_context.operations == ["peek", "append_pending", "finalize"] + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_sync_reconcile_success() -> None: + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + try: + result = ctx.durable_execute( + tracked_call, + "order-1", + reconcile=lambda: "reconciled:order-1", + ) + finally: + _close_runner_context(ctx) + + assert result == "reconciled:order-1" + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert cloudpickle.loads(j_runner_context.call_results[0].result_payload) == ( + "reconciled:order-1" + ) + + +def test_flink_runner_context_sync_reconcile_failure() -> None: + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + def reconcile() -> str: + raise ValueError("failed:order-1") + + try: + with pytest.raises(ValueError, match="failed:order-1"): + ctx.durable_execute(tracked_call, "order-1", reconcile=reconcile) + finally: + _close_runner_context(ctx) + + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "FAILED" + + +def test_flink_runner_context_sync_reconcile_fallback_executes_original_call() -> None: + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + def reconcile() -> str: + raise ReconcileFallbackException(RuntimeError("reconcile unavailable")) + + try: + result = ctx.durable_execute(tracked_call, "order-1", reconcile=reconcile) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert call_count == 1 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_async_writes_pending_on_await() -> None: + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + reconcile_called = False + + def reconcile() -> str: + nonlocal reconcile_called + reconcile_called = True + return "reconciled:order-1" + + try: + async_result = ctx.durable_execute_async( + _call_value, + "order-1", + reconcile=reconcile, + ) + assert j_runner_context.call_results == [] + result = _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconcile_called is False + assert j_runner_context.operations == ["peek", "append_pending", "finalize"] + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_async_reconcile_success() -> None: + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + try: + async_result = ctx.durable_execute_async( + tracked_call, + "order-1", + reconcile=lambda: "reconciled:order-1", + ) + result = _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert result == "reconciled:order-1" + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + + +def test_flink_runner_context_async_reconcile_fallback_executes_original_call() -> None: + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + def reconcile() -> str: + raise ReconcileFallbackException(RuntimeError("reconcile unavailable")) + + try: + async_result = ctx.durable_execute_async( + tracked_call, + "order-1", + reconcile=reconcile, + ) + result = _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert call_count == 1 + assert j_runner_context.operations == ["peek", "finalize"] + + +def test_flink_runner_context_reconcile_kwarg_is_not_forwarded() -> None: + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + + def collect_kwargs(**kwargs): + return kwargs + + try: + result = ctx.durable_execute(collect_kwargs, reconcile=lambda: "unused") + finally: + _close_runner_context(ctx) + + assert result == {} diff --git a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py new file mode 100644 index 000000000..fbc993ce7 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py @@ -0,0 +1,72 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import asyncio + +from flink_agents.runtime.local_runner import LocalRunnerContext + + +def reconciled_add(x: int, y: int) -> int: + return x + y + + +def _create_local_runner_context() -> LocalRunnerContext: + return LocalRunnerContext.__new__(LocalRunnerContext) + + +def test_local_runner_context_reconcile_durable_execute_degrades() -> None: + ctx = _create_local_runner_context() + reconcile_called = False + + def reconcile() -> int: + nonlocal reconcile_called + reconcile_called = True + return 999 + + result = ctx.durable_execute(reconciled_add, 5, 10, reconcile=reconcile) + + assert result == 15 + assert reconcile_called is False + + +def test_local_runner_context_reconcile_durable_execute_async_degrades() -> None: + ctx = _create_local_runner_context() + reconcile_called = False + + def reconcile() -> int: + nonlocal reconcile_called + reconcile_called = True + return 999 + + async_result = ctx.durable_execute_async(reconciled_add, 5, 10, reconcile=reconcile) + + async def _await_result(): + return await async_result + + assert asyncio.run(_await_result()) == 15 + assert reconcile_called is False + + +def test_local_runner_context_reconcile_kwarg_is_not_forwarded() -> None: + ctx = _create_local_runner_context() + + def collect_kwargs(**kwargs): + return kwargs + + result = ctx.durable_execute(collect_kwargs, reconcile=lambda: "unused") + + assert result == {} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index f4d28d3b0..41f2f39bc 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -441,6 +441,24 @@ public void clearCallResultsFromCurrentIndexAndPersist() { } } + /** + * Returns the current durable call result as an array of fields for bridge consumers, or null + * if no persisted slot exists at the current call index. + */ + public Object[] getCurrentCallResultFields() { + CallResult current = getCurrentCallResult(); + if (current == null) { + return null; + } + return new Object[] { + current.getFunctionId(), + current.getArgsDigest(), + current.getStatus().name(), + current.getResultPayload(), + current.getExceptionPayload() + }; + } + protected CallResult getCurrentCallResult() { mailboxThreadChecker.run(); if (durableExecutionContext != null) { From 2ecd2e9211506d41a4aadf9f6d7dfc9865e923a0 Mon Sep 17 00:00:00 2001 From: joeyutong Date: Fri, 3 Apr 2026 15:03:39 +0800 Subject: [PATCH 2/6] Align Python durable reconciler with success-only semantics --- python/flink_agents/api/runner_context.py | 60 +++++------- .../test_reconcile_fallback_exception.py | 34 ------- .../flink_agents/runtime/durable_execution.py | 18 ++-- .../runtime/flink_runner_context.py | 84 ++++++----------- python/flink_agents/runtime/local_runner.py | 4 +- .../runtime/tests/test_durable_execution.py | 38 ++++---- .../test_flink_runner_context_reconcilable.py | 94 +++++++------------ .../tests/test_local_runner_reconcilable.py | 34 +++---- .../runtime/context/RunnerContextImpl.java | 2 +- 9 files changed, 136 insertions(+), 232 deletions(-) delete mode 100644 python/flink_agents/api/tests/test_reconcile_fallback_exception.py diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index a58a5e403..6bb883315 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -28,22 +28,6 @@ from flink_agents.api.memory_object import MemoryObject -class ReconcileFallbackException(Exception): - """Signal that reconcile could not recover a terminal outcome. - - Raising this exception from a reconcile callable tells the runtime to fall back - to the original durable call instead of treating the reconcile failure as a - terminal outcome. - """ - - def __init__(self, cause: BaseException) -> None: - if not isinstance(cause, BaseException): - err_msg = "ReconcileFallbackException requires a BaseException cause" - raise TypeError(err_msg) - super().__init__(str(cause)) - self.cause = cause - - class AsyncExecutionResult: """This class wraps an asynchronous task that will be submitted to a thread pool only when awaited. This ensures lazy submission and serial execution semantics. @@ -212,7 +196,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -229,15 +213,15 @@ def durable_execute( will always make the durable_execute call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. - If `reconcile` is provided, it is used only when recovery resumes from a - `PENDING` durable call result. The reconcile callable should: + If `reconciler` is provided, it is used only during recovery when the + previous durable invocation has not yet produced a persisted terminal + outcome. The reconciler may: - * return the recovered terminal result when the previous invocation - succeeded - * raise the recovered terminal business exception when the previous - invocation failed - * raise `ReconcileFallbackException` when it cannot determine a terminal - outcome and the runtime should execute `func` + * return a result to provide the recovered successful outcome for this + durable call + * raise an exception if it cannot provide a successful outcome; the + exception is propagated to the caller and no recovered terminal + outcome is persisted for this durable call Usage:: @@ -251,8 +235,8 @@ def my_action(event, ctx): The function to be executed. *args : Any Positional arguments to pass to the function. - reconcile : Callable[[], Any] | None - Optional zero-argument reconcile callable used only during recovery. + reconciler : Callable[[], Any] | None + Optional zero-argument reconciler callable used only during recovery. This is a reserved keyword-only parameter and is not forwarded to `func`. **kwargs : Any @@ -269,7 +253,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> "AsyncExecutionResult": """Asynchronously execute the provided function with durable execution support. @@ -283,15 +267,15 @@ def durable_execute_async( will always make the durable_execute_async call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. - If `reconcile` is provided, it is used only when recovery resumes from a - `PENDING` durable call result. The reconcile callable should: + If `reconciler` is provided, it is used only during recovery when the + previous durable invocation has not yet produced a persisted terminal + outcome. The reconciler may: - * return the recovered terminal result when the previous invocation - succeeded - * raise the recovered terminal business exception when the previous - invocation failed - * raise `ReconcileFallbackException` when it cannot determine a terminal - outcome and the runtime should execute `func` + * return a result to provide the recovered successful outcome for this + durable call + * raise an exception if it cannot provide a successful outcome; the + exception is propagated to the caller and no recovered terminal + outcome is persisted for this durable call Usage:: @@ -309,8 +293,8 @@ async def my_action(event, ctx): The function to be executed asynchronously. *args : Any Positional arguments to pass to the function. - reconcile : Callable[[], Any] | None - Optional zero-argument reconcile callable used only during recovery. + reconciler : Callable[[], Any] | None + Optional zero-argument reconciler callable used only during recovery. This is a reserved keyword-only parameter and is not forwarded to `func`. **kwargs : Any diff --git a/python/flink_agents/api/tests/test_reconcile_fallback_exception.py b/python/flink_agents/api/tests/test_reconcile_fallback_exception.py deleted file mode 100644 index 3948a56a5..000000000 --- a/python/flink_agents/api/tests/test_reconcile_fallback_exception.py +++ /dev/null @@ -1,34 +0,0 @@ -################################################################################ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -################################################################################# -import pytest - -from flink_agents.api.runner_context import ReconcileFallbackException - - -def test_reconcile_fallback_exception_preserves_cause() -> None: - cause = ValueError("indeterminate") - - exc = ReconcileFallbackException(cause) - - assert exc.cause is cause - assert str(exc) == "indeterminate" - - -def test_reconcile_fallback_exception_requires_exception_cause() -> None: - with pytest.raises(TypeError, match="requires a BaseException cause"): - ReconcileFallbackException("not-an-exception") # type: ignore[arg-type] diff --git a/python/flink_agents/runtime/durable_execution.py b/python/flink_agents/runtime/durable_execution.py index d4ac9c1f3..9577fa197 100644 --- a/python/flink_agents/runtime/durable_execution.py +++ b/python/flink_agents/runtime/durable_execution.py @@ -56,19 +56,19 @@ def _can_bind_call( return False -def _validate_reconcile_callable( - reconcile: Callable[[], Any] | None, +def _validate_reconciler_callable( + reconciler: Callable[[], Any] | None, ) -> Callable[[], Any] | None: - """Validate that the reconcile callable is either absent or zero-argument.""" - if reconcile is None: + """Validate that the reconciler callable is either absent or zero-argument.""" + if reconciler is None: return None - if not callable(reconcile): - err_msg = "reconcile must be callable" + if not callable(reconciler): + err_msg = "reconciler must be callable" raise TypeError(err_msg) - if not _can_bind_call(reconcile): - err_msg = "reconcile must be a callable that takes no arguments" + if not _can_bind_call(reconciler): + err_msg = "reconciler must be a callable that takes no arguments" raise TypeError(err_msg) - return reconcile + return reconciler diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 20dcfda3f..771ce6130 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -36,7 +36,6 @@ from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.runner_context import ( AsyncExecutionResult, - ReconcileFallbackException, RunnerContext, ) from flink_agents.runtime.flink_memory_object import FlinkMemoryObject @@ -44,7 +43,7 @@ from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, - _validate_reconcile_callable, + _validate_reconciler_callable, ) from flink_agents.runtime.memory.internal_base_long_term_memory import ( InternalBaseLongTermMemory, @@ -71,7 +70,7 @@ class _PersistedCallResult: @dataclass(frozen=True) -class _ReconcileCallResolution: +class _ReconcilerCallResolution: kind: str result: Any = None exception: BaseException | None = None @@ -182,8 +181,8 @@ def __await__(self) -> Any: return result -class _ReconcileDurableAsyncExecutionResult(AsyncExecutionResult): - """An AsyncExecutionResult that resolves reconcile state on await.""" +class _ReconcilerDurableAsyncExecutionResult(AsyncExecutionResult): + """An AsyncExecutionResult that resolves reconciler state on await.""" def __init__( self, @@ -191,18 +190,18 @@ def __init__( executor: Any, func: Callable, args: tuple, - reconcile: Callable[[], Any], + reconciler: Callable[[], Any], kwargs: dict, ) -> None: super().__init__(executor, func, args, kwargs) self._ctx = ctx - self._reconcile = reconcile + self._reconciler = reconciler def __await__(self) -> Any: - resolution = self._ctx._resolve_reconcile_call( + resolution = self._ctx._resolve_reconciler_call( self._func, self._args, - self._reconcile, + self._reconciler, self._kwargs, ) @@ -224,14 +223,7 @@ def __await__(self) -> Any: yield return resolution.result - if resolution.kind == "finalize_failure": - self._ctx._finalize_current_call( - self._func, - self._args, - self._kwargs, - None, - resolution.exception, - ) + if resolution.kind == "raise": raise resolution.exception future = self._executor.submit(self._func, *self._args, **self._kwargs) @@ -552,47 +544,38 @@ def _replay_terminal_call(self, func: Callable, args: tuple, kwargs: dict) -> An raise RuntimeError(err_msg) return cached_result - def _resolve_reconcile_call( + def _resolve_reconciler_call( self, func: Callable, args: tuple, - reconcile: Callable[[], Any], + reconciler: Callable[[], Any], kwargs: dict, - ) -> _ReconcileCallResolution: + ) -> _ReconcilerCallResolution: function_id = _compute_function_id(func) args_digest = _compute_args_digest(args, kwargs) current = self._peek_current_call_result() if current is None: self._append_pending_call(func, args, kwargs) - return _ReconcileCallResolution("execute") + return _ReconcilerCallResolution("execute") if current.function_id != function_id or current.args_digest != args_digest: self._clear_call_results_from_current_index_and_persist() self._append_pending_call(func, args, kwargs) - return _ReconcileCallResolution("execute") + return _ReconcilerCallResolution("execute") if current.status != "PENDING": - return _ReconcileCallResolution("replay") + return _ReconcilerCallResolution("replay") try: - result = reconcile() - except ReconcileFallbackException: - logger.warning( - "Reconcile fell back for Python durable call function_id=%s, args_digest=%s. " - "Falling back to re-execution.", - function_id, - args_digest, - exc_info=True, - ) - return _ReconcileCallResolution("execute") + result = reconciler() except Exception as exception: - return _ReconcileCallResolution( - "finalize_failure", + return _ReconcilerCallResolution( + "raise", exception=exception, ) - return _ReconcileCallResolution("finalize_success", result=result) + return _ReconcilerCallResolution("finalize_success", result=result) def _execute_current_pending_call( self, @@ -642,7 +625,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -655,13 +638,13 @@ def durable_execute( The function is executed synchronously in the current thread, blocking the operator until completion. """ - validated_reconcile = _validate_reconcile_callable(reconcile) + validated_reconciler = _validate_reconciler_callable(reconciler) - if validated_reconcile is not None: - resolution = self._resolve_reconcile_call( + if validated_reconciler is not None: + resolution = self._resolve_reconciler_call( func, args, - validated_reconcile, + validated_reconciler, kwargs, ) if resolution.kind == "replay": @@ -669,14 +652,7 @@ def durable_execute( if resolution.kind == "finalize_success": self._finalize_current_call(func, args, kwargs, resolution.result, None) return resolution.result - if resolution.kind == "finalize_failure": - self._finalize_current_call( - func, - args, - kwargs, - None, - resolution.exception, - ) + if resolution.kind == "raise": raise resolution.exception return self._execute_current_pending_call(func, args, kwargs) @@ -705,7 +681,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function with durable execution support. @@ -719,15 +695,15 @@ def durable_execute_async( is awaited. Fire-and-forget calls (not awaiting the result) will NOT be recorded and cannot be recovered. """ - validated_reconcile = _validate_reconcile_callable(reconcile) + validated_reconciler = _validate_reconciler_callable(reconciler) - if validated_reconcile is not None: - return _ReconcileDurableAsyncExecutionResult( + if validated_reconciler is not None: + return _ReconcilerDurableAsyncExecutionResult( self, self.executor, func, args, - validated_reconcile, + validated_reconciler, kwargs, ) diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 8f7e1ba65..53823b61d 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -192,7 +192,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function. Access to memory @@ -211,7 +211,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, - reconcile: Callable[[], Any] | None = None, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function. Access to memory diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index b09d0b184..920175299 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -23,7 +23,7 @@ from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, - _validate_reconcile_callable, + _validate_reconciler_callable, ) @@ -50,8 +50,8 @@ def class_method(cls, x: int) -> int: return x * 4 -class ReconcileCallables: - """Helpers for reconcile callable validation tests.""" +class ReconcilerCallables: + """Helpers for reconciler callable validation tests.""" def __init__(self, prefix: str) -> None: self.prefix = prefix @@ -142,42 +142,42 @@ def test_compute_args_digest_kwargs_vs_args() -> None: assert digest1 != digest2 -def test_validate_reconcile_callable_accepts_none() -> None: - assert _validate_reconcile_callable(None) is None +def test_validate_reconciler_callable_accepts_none() -> None: + assert _validate_reconciler_callable(None) is None -def test_validate_reconcile_callable_accepts_zero_arg_function() -> None: - def reconcile() -> str: +def test_validate_reconciler_callable_accepts_zero_arg_function() -> None: + def reconciler() -> str: return "ok" - validated = _validate_reconcile_callable(reconcile) + validated = _validate_reconciler_callable(reconciler) - assert validated is reconcile + assert validated is reconciler assert validated() == "ok" -def test_validate_reconcile_callable_accepts_bound_zero_arg_method() -> None: - callables = ReconcileCallables("client") +def test_validate_reconciler_callable_accepts_bound_zero_arg_method() -> None: + callables = ReconcilerCallables("client") bound_method = callables.bound_no_arg - validated = _validate_reconcile_callable(bound_method) + validated = _validate_reconciler_callable(bound_method) assert validated is bound_method assert validated() == "bound:client" -def test_validate_reconcile_callable_requires_callable() -> None: - with pytest.raises(TypeError, match="reconcile must be callable"): - _validate_reconcile_callable(1) # type: ignore[arg-type] +def test_validate_reconciler_callable_requires_callable() -> None: + with pytest.raises(TypeError, match="reconciler must be callable"): + _validate_reconciler_callable(1) # type: ignore[arg-type] -def test_validate_reconcile_callable_requires_zero_args() -> None: - callables = ReconcileCallables("client") +def test_validate_reconciler_callable_requires_zero_args() -> None: + callables = ReconcilerCallables("client") with pytest.raises( - TypeError, match="reconcile must be a callable that takes no arguments" + TypeError, match="reconciler must be a callable that takes no arguments" ): - _validate_reconcile_callable(callables.requires_arg) + _validate_reconciler_callable(callables.requires_arg) def test_cloudpickle_serialization() -> None: diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py index 8afaa1320..3ff30bb57 100644 --- a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -22,7 +22,6 @@ import cloudpickle import pytest -from flink_agents.api.runner_context import ReconcileFallbackException from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, @@ -163,28 +162,28 @@ def _call_value(value: str) -> str: return f"call:{value}" -def test_flink_runner_context_sync_with_reconcile_callable_executes_original_call() -> None: +def test_flink_runner_context_sync_with_reconciler_executes_original_call() -> None: j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) - reconcile_called = False + reconciler_called = False - def reconcile() -> str: - nonlocal reconcile_called - reconcile_called = True + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True return "reconciled:order-1" try: - result = ctx.durable_execute(_call_value, "order-1", reconcile=reconcile) + result = ctx.durable_execute(_call_value, "order-1", reconciler=reconciler) finally: _close_runner_context(ctx) assert result == "call:order-1" - assert reconcile_called is False + assert reconciler_called is False assert j_runner_context.operations == ["peek", "append_pending", "finalize"] assert j_runner_context.call_results[0].status == "SUCCEEDED" -def test_flink_runner_context_sync_reconcile_success() -> None: +def test_flink_runner_context_sync_reconciler_success() -> None: j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -200,7 +199,7 @@ def tracked_call(value: str) -> str: result = ctx.durable_execute( tracked_call, "order-1", - reconcile=lambda: "reconciled:order-1", + reconciler=lambda: "reconciled:order-1", ) finally: _close_runner_context(ctx) @@ -213,7 +212,7 @@ def tracked_call(value: str) -> str: ) -def test_flink_runner_context_sync_reconcile_failure() -> None: +def test_flink_runner_context_sync_reconciler_exception_propagates() -> None: j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -225,61 +224,36 @@ def tracked_call(value: str) -> str: _preload_pending(j_runner_context, tracked_call, "order-1") ctx = _create_runner_context(j_runner_context) - def reconcile() -> str: + def reconciler() -> str: raise ValueError("failed:order-1") try: with pytest.raises(ValueError, match="failed:order-1"): - ctx.durable_execute(tracked_call, "order-1", reconcile=reconcile) + ctx.durable_execute(tracked_call, "order-1", reconciler=reconciler) finally: _close_runner_context(ctx) assert call_count == 0 - assert j_runner_context.operations == ["peek", "finalize"] - assert j_runner_context.call_results[0].status == "FAILED" - - -def test_flink_runner_context_sync_reconcile_fallback_executes_original_call() -> None: - j_runner_context = _FakeJavaRunnerContext() - call_count = 0 - - def tracked_call(value: str) -> str: - nonlocal call_count - call_count += 1 - return _call_value(value) - - _preload_pending(j_runner_context, tracked_call, "order-1") - ctx = _create_runner_context(j_runner_context) - - def reconcile() -> str: - raise ReconcileFallbackException(RuntimeError("reconcile unavailable")) - - try: - result = ctx.durable_execute(tracked_call, "order-1", reconcile=reconcile) - finally: - _close_runner_context(ctx) - - assert result == "call:order-1" - assert call_count == 1 - assert j_runner_context.operations == ["peek", "finalize"] - assert j_runner_context.call_results[0].status == "SUCCEEDED" + assert j_runner_context.operations == ["peek"] + assert j_runner_context.call_results[0].status == "PENDING" + assert j_runner_context.current_call_index == 0 def test_flink_runner_context_async_writes_pending_on_await() -> None: j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) - reconcile_called = False + reconciler_called = False - def reconcile() -> str: - nonlocal reconcile_called - reconcile_called = True + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True return "reconciled:order-1" try: async_result = ctx.durable_execute_async( _call_value, "order-1", - reconcile=reconcile, + reconciler=reconciler, ) assert j_runner_context.call_results == [] result = _run_async(async_result) @@ -287,12 +261,12 @@ def reconcile() -> str: _close_runner_context(ctx) assert result == "call:order-1" - assert reconcile_called is False + assert reconciler_called is False assert j_runner_context.operations == ["peek", "append_pending", "finalize"] assert j_runner_context.call_results[0].status == "SUCCEEDED" -def test_flink_runner_context_async_reconcile_success() -> None: +def test_flink_runner_context_async_reconciler_success() -> None: j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -308,7 +282,7 @@ def tracked_call(value: str) -> str: async_result = ctx.durable_execute_async( tracked_call, "order-1", - reconcile=lambda: "reconciled:order-1", + reconciler=lambda: "reconciled:order-1", ) result = _run_async(async_result) finally: @@ -319,7 +293,7 @@ def tracked_call(value: str) -> str: assert j_runner_context.operations == ["peek", "finalize"] -def test_flink_runner_context_async_reconcile_fallback_executes_original_call() -> None: +def test_flink_runner_context_async_reconciler_exception_propagates() -> None: j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -331,25 +305,27 @@ def tracked_call(value: str) -> str: _preload_pending(j_runner_context, tracked_call, "order-1") ctx = _create_runner_context(j_runner_context) - def reconcile() -> str: - raise ReconcileFallbackException(RuntimeError("reconcile unavailable")) + def reconciler() -> str: + raise RuntimeError("reconcile unavailable") try: async_result = ctx.durable_execute_async( tracked_call, "order-1", - reconcile=reconcile, + reconciler=reconciler, ) - result = _run_async(async_result) + with pytest.raises(RuntimeError, match="reconcile unavailable"): + _run_async(async_result) finally: _close_runner_context(ctx) - assert result == "call:order-1" - assert call_count == 1 - assert j_runner_context.operations == ["peek", "finalize"] + assert call_count == 0 + assert j_runner_context.operations == ["peek"] + assert j_runner_context.call_results[0].status == "PENDING" + assert j_runner_context.current_call_index == 0 -def test_flink_runner_context_reconcile_kwarg_is_not_forwarded() -> None: +def test_flink_runner_context_reconciler_kwarg_is_not_forwarded() -> None: j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) @@ -357,7 +333,7 @@ def collect_kwargs(**kwargs): return kwargs try: - result = ctx.durable_execute(collect_kwargs, reconcile=lambda: "unused") + result = ctx.durable_execute(collect_kwargs, reconciler=lambda: "unused") finally: _close_runner_context(ctx) diff --git a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py index fbc993ce7..21e8cd4cc 100644 --- a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py @@ -28,45 +28,47 @@ def _create_local_runner_context() -> LocalRunnerContext: return LocalRunnerContext.__new__(LocalRunnerContext) -def test_local_runner_context_reconcile_durable_execute_degrades() -> None: +def test_local_runner_context_reconciler_durable_execute_degrades() -> None: ctx = _create_local_runner_context() - reconcile_called = False + reconciler_called = False - def reconcile() -> int: - nonlocal reconcile_called - reconcile_called = True + def reconciler() -> int: + nonlocal reconciler_called + reconciler_called = True return 999 - result = ctx.durable_execute(reconciled_add, 5, 10, reconcile=reconcile) + result = ctx.durable_execute(reconciled_add, 5, 10, reconciler=reconciler) assert result == 15 - assert reconcile_called is False + assert reconciler_called is False -def test_local_runner_context_reconcile_durable_execute_async_degrades() -> None: +def test_local_runner_context_reconciler_durable_execute_async_degrades() -> None: ctx = _create_local_runner_context() - reconcile_called = False + reconciler_called = False - def reconcile() -> int: - nonlocal reconcile_called - reconcile_called = True + def reconciler() -> int: + nonlocal reconciler_called + reconciler_called = True return 999 - async_result = ctx.durable_execute_async(reconciled_add, 5, 10, reconcile=reconcile) + async_result = ctx.durable_execute_async( + reconciled_add, 5, 10, reconciler=reconciler + ) async def _await_result(): return await async_result assert asyncio.run(_await_result()) == 15 - assert reconcile_called is False + assert reconciler_called is False -def test_local_runner_context_reconcile_kwarg_is_not_forwarded() -> None: +def test_local_runner_context_reconciler_kwarg_is_not_forwarded() -> None: ctx = _create_local_runner_context() def collect_kwargs(**kwargs): return kwargs - result = ctx.durable_execute(collect_kwargs, reconcile=lambda: "unused") + result = ctx.durable_execute(collect_kwargs, reconciler=lambda: "unused") assert result == {} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 41f2f39bc..3e302854a 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -453,7 +453,7 @@ public Object[] getCurrentCallResultFields() { return new Object[] { current.getFunctionId(), current.getArgsDigest(), - current.getStatus().name(), + current.isPending() ? "PENDING" : current.isFailure() ? "FAILED" : "SUCCEEDED", current.getResultPayload(), current.getExceptionPayload() }; From 32145ef5f2890efc315f0465d9bc6fe6dd44448c Mon Sep 17 00:00:00 2001 From: joeyutong Date: Wed, 8 Apr 2026 10:32:56 +0800 Subject: [PATCH 3/6] Fix Python code style issues --- .../flink_agents/runtime/durable_execution.py | 3 +- .../runtime/flink_runner_context.py | 9 ++--- .../runtime/tests/test_durable_execution.py | 8 +++++ .../test_flink_runner_context_reconcilable.py | 33 ++++++++++++++----- .../tests/test_local_runner_reconcilable.py | 9 +++-- 5 files changed, 45 insertions(+), 17 deletions(-) diff --git a/python/flink_agents/runtime/durable_execution.py b/python/flink_agents/runtime/durable_execution.py index 9577fa197..9db22bfa9 100644 --- a/python/flink_agents/runtime/durable_execution.py +++ b/python/flink_agents/runtime/durable_execution.py @@ -51,9 +51,10 @@ def _can_bind_call( """Return whether the callable signature can bind the provided arguments.""" try: inspect.signature(func).bind(*args, **kwargs) - return True except (TypeError, ValueError): return False + else: + return True def _validate_reconciler_callable( diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 771ce6130..2f085872c 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -19,7 +19,7 @@ import os from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import Any, Callable, Dict import cloudpickle from typing_extensions import override @@ -38,13 +38,13 @@ AsyncExecutionResult, RunnerContext, ) -from flink_agents.runtime.flink_memory_object import FlinkMemoryObject -from flink_agents.runtime.flink_metric_group import FlinkMetricGroup from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, _validate_reconciler_callable, ) +from flink_agents.runtime.flink_memory_object import FlinkMemoryObject +from flink_agents.runtime.flink_metric_group import FlinkMetricGroup from flink_agents.runtime.memory.internal_base_long_term_memory import ( InternalBaseLongTermMemory, ) @@ -54,9 +54,6 @@ from flink_agents.runtime.python_java_utils import _build_event_log_string from flink_agents.runtime.resource_cache import ResourceCache -if TYPE_CHECKING: - from flink_agents.plan.agent_plan import AgentPlan - logger = logging.getLogger(__name__) diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index 920175299..52d777a2c 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -54,12 +54,15 @@ class ReconcilerCallables: """Helpers for reconciler callable validation tests.""" def __init__(self, prefix: str) -> None: + """Store a prefix used by the helper callables.""" self.prefix = prefix def bound_no_arg(self) -> str: + """Return a bound zero-argument reconciler result.""" return f"bound:{self.prefix}" def requires_arg(self, value: int) -> str: + """Require an argument so validation can reject the callable.""" return f"{self.prefix}:{value}" @@ -143,10 +146,12 @@ def test_compute_args_digest_kwargs_vs_args() -> None: def test_validate_reconciler_callable_accepts_none() -> None: + """Allow omitting the reconciler callable.""" assert _validate_reconciler_callable(None) is None def test_validate_reconciler_callable_accepts_zero_arg_function() -> None: + """Accept a zero-argument reconciler function.""" def reconciler() -> str: return "ok" @@ -157,6 +162,7 @@ def reconciler() -> str: def test_validate_reconciler_callable_accepts_bound_zero_arg_method() -> None: + """Accept a bound reconciler method with no remaining arguments.""" callables = ReconcilerCallables("client") bound_method = callables.bound_no_arg @@ -167,11 +173,13 @@ def test_validate_reconciler_callable_accepts_bound_zero_arg_method() -> None: def test_validate_reconciler_callable_requires_callable() -> None: + """Reject non-callable reconciler values.""" with pytest.raises(TypeError, match="reconciler must be callable"): _validate_reconciler_callable(1) # type: ignore[arg-type] def test_validate_reconciler_callable_requires_zero_args() -> None: + """Reject reconciler callables that require arguments.""" callables = ReconcilerCallables("client") with pytest.raises( diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py index 3ff30bb57..1cbf90f26 100644 --- a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -18,6 +18,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from typing import Any, Callable import cloudpickle import pytest @@ -44,7 +45,7 @@ def __init__(self) -> None: self.current_call_index = 0 self.operations: list[str] = [] - def getCurrentCallResultFields(self): + def getCurrentCallResultFields(self) -> list[Any] | None: self.operations.append("peek") if self.current_call_index < len(self.call_results): current = self.call_results[self.current_call_index] @@ -57,7 +58,9 @@ def getCurrentCallResultFields(self): ] return None - def matchNextOrClearSubsequentCallResult(self, function_id: str, args_digest: str): + def matchNextOrClearSubsequentCallResult( + self, function_id: str, args_digest: str + ) -> list[Any] | None: self.operations.append("match") if self.current_call_index < len(self.call_results): current = self.call_results[self.current_call_index] @@ -141,14 +144,19 @@ def _close_runner_context(ctx: FlinkRunnerContext) -> None: ctx.executor.shutdown(wait=True) -def _run_async(result) -> object: - async def _await_result(): +def _run_async(result: Any) -> object: + async def _await_result() -> Any: return await result return asyncio.run(_await_result()) -def _preload_pending(j_runner_context: _FakeJavaRunnerContext, func, *args, **kwargs) -> None: +def _preload_pending( + j_runner_context: _FakeJavaRunnerContext, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> None: j_runner_context.call_results.append( _StoredCallResult( function_id=_compute_function_id(func), @@ -163,6 +171,7 @@ def _call_value(value: str) -> str: def test_flink_runner_context_sync_with_reconciler_executes_original_call() -> None: + """Start a new durable call when no pending state exists.""" j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) reconciler_called = False @@ -184,6 +193,7 @@ def reconciler() -> str: def test_flink_runner_context_sync_reconciler_success() -> None: + """Persist a recovered success without re-executing the original call.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -213,6 +223,7 @@ def tracked_call(value: str) -> str: def test_flink_runner_context_sync_reconciler_exception_propagates() -> None: + """Propagate reconciler exceptions without finalizing the pending slot.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -225,7 +236,8 @@ def tracked_call(value: str) -> str: ctx = _create_runner_context(j_runner_context) def reconciler() -> str: - raise ValueError("failed:order-1") + error_message = "failed:order-1" + raise ValueError(error_message) try: with pytest.raises(ValueError, match="failed:order-1"): @@ -240,6 +252,7 @@ def reconciler() -> str: def test_flink_runner_context_async_writes_pending_on_await() -> None: + """Defer pending-state writes for async execution until await time.""" j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) reconciler_called = False @@ -267,6 +280,7 @@ def reconciler() -> str: def test_flink_runner_context_async_reconciler_success() -> None: + """Recover a successful async result through the reconciler.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -294,6 +308,7 @@ def tracked_call(value: str) -> str: def test_flink_runner_context_async_reconciler_exception_propagates() -> None: + """Propagate async reconciler exceptions without advancing state.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -306,7 +321,8 @@ def tracked_call(value: str) -> str: ctx = _create_runner_context(j_runner_context) def reconciler() -> str: - raise RuntimeError("reconcile unavailable") + error_message = "reconcile unavailable" + raise RuntimeError(error_message) try: async_result = ctx.durable_execute_async( @@ -326,10 +342,11 @@ def reconciler() -> str: def test_flink_runner_context_reconciler_kwarg_is_not_forwarded() -> None: + """Keep the reserved reconciler kwarg out of the user function call.""" j_runner_context = _FakeJavaRunnerContext() ctx = _create_runner_context(j_runner_context) - def collect_kwargs(**kwargs): + def collect_kwargs(**kwargs: Any) -> dict[str, Any]: return kwargs try: diff --git a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py index 21e8cd4cc..a213423f2 100644 --- a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py @@ -16,11 +16,13 @@ # limitations under the License. ################################################################################# import asyncio +from typing import Any from flink_agents.runtime.local_runner import LocalRunnerContext def reconciled_add(x: int, y: int) -> int: + """Return a simple deterministic value for local-runner tests.""" return x + y @@ -29,6 +31,7 @@ def _create_local_runner_context() -> LocalRunnerContext: def test_local_runner_context_reconciler_durable_execute_degrades() -> None: + """Keep sync local execution on the existing non-durable path.""" ctx = _create_local_runner_context() reconciler_called = False @@ -44,6 +47,7 @@ def reconciler() -> int: def test_local_runner_context_reconciler_durable_execute_async_degrades() -> None: + """Keep async local execution on the existing non-durable path.""" ctx = _create_local_runner_context() reconciler_called = False @@ -56,7 +60,7 @@ def reconciler() -> int: reconciled_add, 5, 10, reconciler=reconciler ) - async def _await_result(): + async def _await_result() -> Any: return await async_result assert asyncio.run(_await_result()) == 15 @@ -64,9 +68,10 @@ async def _await_result(): def test_local_runner_context_reconciler_kwarg_is_not_forwarded() -> None: + """Do not forward the reserved reconciler kwarg to the user function.""" ctx = _create_local_runner_context() - def collect_kwargs(**kwargs): + def collect_kwargs(**kwargs: Any) -> dict[str, Any]: return kwargs result = ctx.durable_execute(collect_kwargs, reconciler=lambda: "unused") From ac1b7c14fbe2ca496c1ee534c64f086523efe58a Mon Sep 17 00:00:00 2001 From: joeyutong Date: Wed, 8 Apr 2026 14:05:31 +0800 Subject: [PATCH 4/6] Add mismatch recovery test for Python reconciler --- .../test_flink_runner_context_reconcilable.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py index 1cbf90f26..9335aa15a 100644 --- a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -251,6 +251,49 @@ def reconciler() -> str: assert j_runner_context.current_call_index == 0 +def test_flink_runner_context_sync_reconciler_mismatch_clears_and_executes() -> None: + """Clear mismatched persisted state before executing the original call.""" + j_runner_context = _FakeJavaRunnerContext() + stale_result_payload = cloudpickle.dumps("stale") + j_runner_context.call_results.extend( + [ + _StoredCallResult( + function_id=_compute_function_id(_call_value), + args_digest=_compute_args_digest(("other-order",), {}), + status="PENDING", + ), + _StoredCallResult( + function_id="stale.function", + args_digest="stale-args", + status="SUCCEEDED", + result_payload=stale_result_payload, + ), + ] + ) + ctx = _create_runner_context(j_runner_context) + reconciler_called = False + + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True + return "reconciled:order-1" + + try: + result = ctx.durable_execute(_call_value, "order-1", reconciler=reconciler) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconciler_called is False + assert j_runner_context.operations == ["peek", "clear", "append_pending", "finalize"] + assert len(j_runner_context.call_results) == 1 + assert j_runner_context.call_results[0].function_id == _compute_function_id(_call_value) + assert j_runner_context.call_results[0].args_digest == _compute_args_digest( + ("order-1",), {} + ) + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + def test_flink_runner_context_async_writes_pending_on_await() -> None: """Defer pending-state writes for async execution until await time.""" j_runner_context = _FakeJavaRunnerContext() From 8ac985abeaff129c4823cadec4e46e978433451e Mon Sep 17 00:00:00 2001 From: joeyutong Date: Wed, 8 Apr 2026 14:13:06 +0800 Subject: [PATCH 5/6] Clarify Python reconciler docstrings --- python/flink_agents/api/runner_context.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 6bb883315..5f3daebdb 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -213,9 +213,9 @@ def durable_execute( will always make the durable_execute call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. - If `reconciler` is provided, it is used only during recovery when the - previous durable invocation has not yet produced a persisted terminal - outcome. The reconciler may: + If `reconciler` is provided, recovery invokes it only when revisiting + this durable call and no terminal outcome from the previous durable + invocation has been persisted yet. The reconciler may: * return a result to provide the recovered successful outcome for this durable call @@ -267,9 +267,9 @@ def durable_execute_async( will always make the durable_execute_async call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. - If `reconciler` is provided, it is used only during recovery when the - previous durable invocation has not yet produced a persisted terminal - outcome. The reconciler may: + If `reconciler` is provided, recovery invokes it only when revisiting + this durable call and no terminal outcome from the previous durable + invocation has been persisted yet. The reconciler may: * return a result to provide the recovered successful outcome for this durable call From f85bd8fbe0a14b9255f3293df877f830d225beec Mon Sep 17 00:00:00 2001 From: joeyutong Date: Mon, 13 Apr 2026 12:56:41 +0800 Subject: [PATCH 6/6] Refine Python reconciler recovery flow --- python/flink_agents/api/runner_context.py | 16 +-- .../runtime/flink_runner_context.py | 105 ++++++++++-------- .../test_flink_runner_context_reconcilable.py | 30 +++-- 3 files changed, 86 insertions(+), 65 deletions(-) diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 5f3daebdb..bf8d4ff9e 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -218,10 +218,10 @@ def durable_execute( invocation has been persisted yet. The reconciler may: * return a result to provide the recovered successful outcome for this - durable call - * raise an exception if it cannot provide a successful outcome; the - exception is propagated to the caller and no recovered terminal - outcome is persisted for this durable call + durable call; The runtime persists and replays that recovered result + * raise an exception to provide the recovered failed outcome for this + durable call; The runtime persists and replays that recovered + failure Usage:: @@ -272,10 +272,10 @@ def durable_execute_async( invocation has been persisted yet. The reconciler may: * return a result to provide the recovered successful outcome for this - durable call - * raise an exception if it cannot provide a successful outcome; the - exception is propagated to the caller and no recovered terminal - outcome is persisted for this durable call + durable call; The runtime persists and replays that recovered result + * raise an exception to provide the recovered failed outcome for this + durable call; The runtime persists and replays that recovered + failure Usage:: diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 2f085872c..4daeb6279 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -19,7 +19,8 @@ import os from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Any, Callable, Dict +from functools import partial +from typing import Any, Callable, Dict, Literal import cloudpickle from typing_extensions import override @@ -67,10 +68,11 @@ class _PersistedCallResult: @dataclass(frozen=True) -class _ReconcilerCallResolution: - kind: str - result: Any = None - exception: BaseException | None = None +class _ReconcilerExecutionPlan: + mode: Literal["replay", "execute"] + callable: Callable[[], Any] | None = None + needs_clear: bool = False + needs_append_pending: bool = False class _DurableExecutionResult: @@ -195,35 +197,27 @@ def __init__( self._reconciler = reconciler def __await__(self) -> Any: - resolution = self._ctx._resolve_reconciler_call( + plan = self._ctx._plan_reconciler_execution( self._func, self._args, self._reconciler, self._kwargs, ) - if resolution.kind == "replay": + if plan.mode == "replay": result = self._ctx._replay_terminal_call(self._func, self._args, self._kwargs) if False: yield return result - if resolution.kind == "finalize_success": - self._ctx._finalize_current_call( - self._func, - self._args, - self._kwargs, - resolution.result, - None, - ) - if False: - yield - return resolution.result - - if resolution.kind == "raise": - raise resolution.exception + self._ctx._prepare_reconciler_execution( + plan, + self._func, + self._args, + self._kwargs, + ) - future = self._executor.submit(self._func, *self._args, **self._kwargs) + future = self._executor.submit(plan.callable) while not future.done(): yield @@ -541,41 +535,56 @@ def _replay_terminal_call(self, func: Callable, args: tuple, kwargs: dict) -> An raise RuntimeError(err_msg) return cached_result - def _resolve_reconciler_call( + def _plan_reconciler_execution( self, func: Callable, args: tuple, reconciler: Callable[[], Any], kwargs: dict, - ) -> _ReconcilerCallResolution: + ) -> _ReconcilerExecutionPlan: function_id = _compute_function_id(func) args_digest = _compute_args_digest(args, kwargs) current = self._peek_current_call_result() + durable_call = partial(func, *args, **kwargs) if current is None: - self._append_pending_call(func, args, kwargs) - return _ReconcilerCallResolution("execute") + return _ReconcilerExecutionPlan( + "execute", + callable=durable_call, + needs_append_pending=True, + ) if current.function_id != function_id or current.args_digest != args_digest: - self._clear_call_results_from_current_index_and_persist() - self._append_pending_call(func, args, kwargs) - return _ReconcilerCallResolution("execute") + return _ReconcilerExecutionPlan( + "execute", + callable=durable_call, + needs_clear=True, + needs_append_pending=True, + ) if current.status != "PENDING": - return _ReconcilerCallResolution("replay") + return _ReconcilerExecutionPlan("replay") - try: - result = reconciler() - except Exception as exception: - return _ReconcilerCallResolution( - "raise", - exception=exception, - ) + return _ReconcilerExecutionPlan( + "execute", + callable=reconciler, + ) - return _ReconcilerCallResolution("finalize_success", result=result) + def _prepare_reconciler_execution( + self, + plan: _ReconcilerExecutionPlan, + func: Callable, + args: tuple, + kwargs: dict, + ) -> None: + if plan.needs_clear: + self._clear_call_results_from_current_index_and_persist() + if plan.needs_append_pending: + self._append_pending_call(func, args, kwargs) def _execute_current_pending_call( self, + execution_callable: Callable[[], Any], func: Callable, args: tuple, kwargs: dict, @@ -583,7 +592,7 @@ def _execute_current_pending_call( exception = None result = None try: - result = func(*args, **kwargs) + result = execution_callable() except BaseException as e: exception = e @@ -638,20 +647,22 @@ def durable_execute( validated_reconciler = _validate_reconciler_callable(reconciler) if validated_reconciler is not None: - resolution = self._resolve_reconciler_call( + plan = self._plan_reconciler_execution( func, args, validated_reconciler, kwargs, ) - if resolution.kind == "replay": + if plan.mode == "replay": return self._replay_terminal_call(func, args, kwargs) - if resolution.kind == "finalize_success": - self._finalize_current_call(func, args, kwargs, resolution.result, None) - return resolution.result - if resolution.kind == "raise": - raise resolution.exception - return self._execute_current_pending_call(func, args, kwargs) + + self._prepare_reconciler_execution(plan, func, args, kwargs) + return self._execute_current_pending_call( + plan.callable, + func, + args, + kwargs, + ) # Try to get cached result for recovery is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py index 9335aa15a..4d64eae41 100644 --- a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -222,8 +222,8 @@ def tracked_call(value: str) -> str: ) -def test_flink_runner_context_sync_reconciler_exception_propagates() -> None: - """Propagate reconciler exceptions without finalizing the pending slot.""" +def test_flink_runner_context_sync_reconciler_exception_persists_failure() -> None: + """Persist a recovered failure from the reconciler and re-raise it.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -246,9 +246,14 @@ def reconciler() -> str: _close_runner_context(ctx) assert call_count == 0 - assert j_runner_context.operations == ["peek"] - assert j_runner_context.call_results[0].status == "PENDING" - assert j_runner_context.current_call_index == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "FAILED" + persisted_exception = cloudpickle.loads( + j_runner_context.call_results[0].exception_payload + ) + assert isinstance(persisted_exception, ValueError) + assert str(persisted_exception) == "failed:order-1" + assert j_runner_context.current_call_index == 1 def test_flink_runner_context_sync_reconciler_mismatch_clears_and_executes() -> None: @@ -350,8 +355,8 @@ def tracked_call(value: str) -> str: assert j_runner_context.operations == ["peek", "finalize"] -def test_flink_runner_context_async_reconciler_exception_propagates() -> None: - """Propagate async reconciler exceptions without advancing state.""" +def test_flink_runner_context_async_reconciler_exception_persists_failure() -> None: + """Persist an async reconciler failure and re-raise it.""" j_runner_context = _FakeJavaRunnerContext() call_count = 0 @@ -379,9 +384,14 @@ def reconciler() -> str: _close_runner_context(ctx) assert call_count == 0 - assert j_runner_context.operations == ["peek"] - assert j_runner_context.call_results[0].status == "PENDING" - assert j_runner_context.current_call_index == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "FAILED" + persisted_exception = cloudpickle.loads( + j_runner_context.call_results[0].exception_payload + ) + assert isinstance(persisted_exception, RuntimeError) + assert str(persisted_exception) == "reconcile unavailable" + assert j_runner_context.current_call_index == 1 def test_flink_runner_context_reconciler_kwarg_is_not_forwarded() -> None: