diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py index 7dfd54376..bf8d4ff9e 100644 --- a/python/flink_agents/api/runner_context.py +++ b/python/flink_agents/api/runner_context.py @@ -196,6 +196,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -212,6 +213,16 @@ def durable_execute( will always make the durable_execute call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. + If `reconciler` is provided, recovery invokes it only when revisiting + this durable call and no terminal outcome from the previous durable + invocation has been persisted yet. The reconciler may: + + * return a result to provide the recovered successful outcome for this + durable call; The runtime persists and replays that recovered result + * raise an exception to provide the recovered failed outcome for this + durable call; The runtime persists and replays that recovered + failure + Usage:: def my_action(event, ctx): @@ -224,6 +235,10 @@ def my_action(event, ctx): The function to be executed. *args : Any Positional arguments to pass to the function. + reconciler : Callable[[], Any] | None + Optional zero-argument reconciler callable used only during recovery. + This is a reserved keyword-only parameter and is not forwarded to + `func`. **kwargs : Any Keyword arguments to pass to the function. @@ -238,6 +253,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> "AsyncExecutionResult": """Asynchronously execute the provided function with durable execution support. @@ -251,6 +267,16 @@ def durable_execute_async( will always make the durable_execute_async call with the same arguments and in the same order during job recovery. Otherwise, the behavior is undefined. + If `reconciler` is provided, recovery invokes it only when revisiting + this durable call and no terminal outcome from the previous durable + invocation has been persisted yet. The reconciler may: + + * return a result to provide the recovered successful outcome for this + durable call; The runtime persists and replays that recovered result + * raise an exception to provide the recovered failed outcome for this + durable call; The runtime persists and replays that recovered + failure + Usage:: async def my_action(event, ctx): @@ -267,6 +293,10 @@ async def my_action(event, ctx): The function to be executed asynchronously. *args : Any Positional arguments to pass to the function. + reconciler : Callable[[], Any] | None + Optional zero-argument reconciler callable used only during recovery. + This is a reserved keyword-only parameter and is not forwarded to + `func`. **kwargs : Any Keyword arguments to pass to the function. diff --git a/python/flink_agents/runtime/durable_execution.py b/python/flink_agents/runtime/durable_execution.py new file mode 100644 index 000000000..9db22bfa9 --- /dev/null +++ b/python/flink_agents/runtime/durable_execution.py @@ -0,0 +1,75 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import hashlib +import inspect +from typing import Any, Callable + +import cloudpickle + + +def _compute_function_id(func: Callable) -> str: + """Compute a stable function identifier from a callable.""" + module_obj = inspect.getmodule(func) + module = ( + module_obj.__name__ + if module_obj is not None + else getattr(func, "__module__", "") + ) + qualname = getattr(func, "__qualname__", getattr(func, "__name__", "")) + return f"{module}.{qualname}" + + +def _compute_args_digest(args: tuple, kwargs: dict) -> str: + """Compute a stable digest of the serialized arguments.""" + try: + serialized = cloudpickle.dumps((args, kwargs)) + return hashlib.sha256(serialized).hexdigest()[:16] + except Exception: + return hashlib.sha256(str((args, kwargs)).encode()).hexdigest()[:16] + + +def _can_bind_call( + func: Callable, + *args: Any, + **kwargs: Any, +) -> bool: + """Return whether the callable signature can bind the provided arguments.""" + try: + inspect.signature(func).bind(*args, **kwargs) + except (TypeError, ValueError): + return False + else: + return True + + +def _validate_reconciler_callable( + reconciler: Callable[[], Any] | None, +) -> Callable[[], Any] | None: + """Validate that the reconciler callable is either absent or zero-argument.""" + if reconciler is None: + return None + + if not callable(reconciler): + err_msg = "reconciler must be callable" + raise TypeError(err_msg) + + if not _can_bind_call(reconciler): + err_msg = "reconciler must be a callable that takes no arguments" + raise TypeError(err_msg) + + return reconciler diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 631267fae..4daeb6279 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -15,11 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -import hashlib import logging import os from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Literal import cloudpickle from typing_extensions import override @@ -34,8 +35,15 @@ from flink_agents.api.memory_object import MemoryType from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType -from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext -from flink_agents.plan.agent_plan import AgentPlan +from flink_agents.api.runner_context import ( + AsyncExecutionResult, + RunnerContext, +) +from flink_agents.runtime.durable_execution import ( + _compute_args_digest, + _compute_function_id, + _validate_reconciler_callable, +) from flink_agents.runtime.flink_memory_object import FlinkMemoryObject from flink_agents.runtime.flink_metric_group import FlinkMetricGroup from flink_agents.runtime.memory.internal_base_long_term_memory import ( @@ -50,6 +58,23 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class _PersistedCallResult: + function_id: str + args_digest: str + status: str + result_payload: bytes | None + exception_payload: bytes | None + + +@dataclass(frozen=True) +class _ReconcilerExecutionPlan: + mode: Literal["replay", "execute"] + callable: Callable[[], Any] | None = None + needs_clear: bool = False + needs_append_pending: bool = False + + class _DurableExecutionResult: """Wrapper that holds result and triggers recording when unwrapped.""" @@ -155,28 +180,65 @@ def __await__(self) -> Any: return result -def _compute_function_id(func: Callable) -> str: - """Compute a stable function identifier from a callable. +class _ReconcilerDurableAsyncExecutionResult(AsyncExecutionResult): + """An AsyncExecutionResult that resolves reconciler state on await.""" - Returns module.qualname for functions/methods. - """ - module = getattr(func, "__module__", "") - qualname = getattr(func, "__qualname__", getattr(func, "__name__", "")) - return f"{module}.{qualname}" + def __init__( + self, + ctx: "FlinkRunnerContext", + executor: Any, + func: Callable, + args: tuple, + reconciler: Callable[[], Any], + kwargs: dict, + ) -> None: + super().__init__(executor, func, args, kwargs) + self._ctx = ctx + self._reconciler = reconciler + + def __await__(self) -> Any: + plan = self._ctx._plan_reconciler_execution( + self._func, + self._args, + self._reconciler, + self._kwargs, + ) + if plan.mode == "replay": + result = self._ctx._replay_terminal_call(self._func, self._args, self._kwargs) + if False: + yield + return result -def _compute_args_digest(args: tuple, kwargs: dict) -> str: - """Compute a stable digest of the serialized arguments. + self._ctx._prepare_reconciler_execution( + plan, + self._func, + self._args, + self._kwargs, + ) - The digest is used to validate that the same arguments are passed - during recovery as during the original execution. - """ - try: - serialized = cloudpickle.dumps((args, kwargs)) - return hashlib.sha256(serialized).hexdigest()[:16] - except Exception: - # If serialization fails, return a fallback digest - return hashlib.sha256(str((args, kwargs)).encode()).hexdigest()[:16] + future = self._executor.submit(plan.callable) + while not future.done(): + yield + + exception = None + result = None + try: + result = future.result() + except BaseException as e: + exception = e + + self._ctx._finalize_current_call( + self._func, + self._args, + self._kwargs, + result, + exception, + ) + + if exception is not None: + raise exception + return result class FlinkRunnerContext(RunnerContext): @@ -186,7 +248,7 @@ class FlinkRunnerContext(RunnerContext): durable execution support through execute() and execute_async() methods. """ - __agent_plan: AgentPlan | None + __agent_plan: Any __ltm: InternalBaseLongTermMemory = None def __init__( @@ -203,6 +265,8 @@ def __init__( j_runner_context : Any Java runner context used to synchronize data between Python and Java. """ + from flink_agents.plan.agent_plan import AgentPlan + self._j_runner_context = j_runner_context self.__agent_plan = AgentPlan.model_validate_json(agent_plan_json) self.__resource_cache = ResourceCache( @@ -409,11 +473,165 @@ def _record_call_completion( if "recordCallCompletion" not in str(e): logger.warning("Failed to record call completion: %s", e) + @staticmethod + def _serialize_call_payloads( + result: Any, + exception: BaseException | None, + ) -> tuple[bytes | None, bytes | None]: + result_payload = None if exception else cloudpickle.dumps(result) + exception_payload = cloudpickle.dumps(exception) if exception else None + return result_payload, exception_payload + + def _peek_current_call_result(self) -> _PersistedCallResult | None: + current = self._j_runner_context.getCurrentCallResultFields() + if current is None: + return None + + function_id, args_digest, status, result_payload, exception_payload = current + return _PersistedCallResult( + function_id=function_id, + args_digest=args_digest, + status=status, + result_payload=bytes(result_payload) if result_payload is not None else None, + exception_payload=( + bytes(exception_payload) if exception_payload is not None else None + ), + ) + + def _append_pending_call(self, func: Callable, args: tuple, kwargs: dict) -> None: + self._j_runner_context.appendPendingCall( + _compute_function_id(func), + _compute_args_digest(args, kwargs), + ) + + def _finalize_current_call( + self, + func: Callable, + args: tuple, + kwargs: dict, + result: Any, + exception: BaseException | None, + ) -> None: + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + result_payload, exception_payload = self._serialize_call_payloads( + result, + exception, + ) + self._j_runner_context.finalizeCurrentCall( + function_id, + args_digest, + result_payload, + exception_payload, + ) + + def _clear_call_results_from_current_index_and_persist(self) -> None: + self._j_runner_context.clearCallResultsFromCurrentIndexAndPersist() + + def _replay_terminal_call(self, func: Callable, args: tuple, kwargs: dict) -> Any: + is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) + if not is_hit: + err_msg = "Expected a terminal durable call result but replay did not hit" + raise RuntimeError(err_msg) + return cached_result + + def _plan_reconciler_execution( + self, + func: Callable, + args: tuple, + reconciler: Callable[[], Any], + kwargs: dict, + ) -> _ReconcilerExecutionPlan: + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + current = self._peek_current_call_result() + durable_call = partial(func, *args, **kwargs) + + if current is None: + return _ReconcilerExecutionPlan( + "execute", + callable=durable_call, + needs_append_pending=True, + ) + + if current.function_id != function_id or current.args_digest != args_digest: + return _ReconcilerExecutionPlan( + "execute", + callable=durable_call, + needs_clear=True, + needs_append_pending=True, + ) + + if current.status != "PENDING": + return _ReconcilerExecutionPlan("replay") + + return _ReconcilerExecutionPlan( + "execute", + callable=reconciler, + ) + + def _prepare_reconciler_execution( + self, + plan: _ReconcilerExecutionPlan, + func: Callable, + args: tuple, + kwargs: dict, + ) -> None: + if plan.needs_clear: + self._clear_call_results_from_current_index_and_persist() + if plan.needs_append_pending: + self._append_pending_call(func, args, kwargs) + + def _execute_current_pending_call( + self, + execution_callable: Callable[[], Any], + func: Callable, + args: tuple, + kwargs: dict, + ) -> Any: + exception = None + result = None + try: + result = execution_callable() + except BaseException as e: + exception = e + + self._finalize_current_call(func, args, kwargs, result, exception) + + if exception is not None: + raise exception + return result + + def _wrap_completion_only_func( + self, + func: Callable, + args: tuple, + kwargs: dict, + ) -> Callable[..., Any]: + def wrapped_func(*a: Any, **kw: Any) -> Any: + exception = None + result = None + try: + result = func(*a, **kw) + except BaseException as e: + exception = e + + if exception: + raise _DurableExecutionException( + func, args, kwargs, result, exception, self._record_call_completion + ) + return _DurableExecutionResult( + func, args, kwargs, result, self._record_call_completion + ) + + return wrapped_func + @override def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function with durable execution support. @@ -426,6 +644,26 @@ def durable_execute( The function is executed synchronously in the current thread, blocking the operator until completion. """ + validated_reconciler = _validate_reconciler_callable(reconciler) + + if validated_reconciler is not None: + plan = self._plan_reconciler_execution( + func, + args, + validated_reconciler, + kwargs, + ) + if plan.mode == "replay": + return self._replay_terminal_call(func, args, kwargs) + + self._prepare_reconciler_execution(plan, func, args, kwargs) + return self._execute_current_pending_call( + plan.callable, + func, + args, + kwargs, + ) + # Try to get cached result for recovery is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) if is_hit: @@ -451,6 +689,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function with durable execution support. @@ -464,32 +703,30 @@ def durable_execute_async( is awaited. Fire-and-forget calls (not awaiting the result) will NOT be recorded and cannot be recovered. """ + validated_reconciler = _validate_reconciler_callable(reconciler) + + if validated_reconciler is not None: + return _ReconcilerDurableAsyncExecutionResult( + self, + self.executor, + func, + args, + validated_reconciler, + kwargs, + ) + # Try to get cached result for recovery is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) if is_hit: # Return a pre-completed AsyncExecutionResult return _CachedAsyncExecutionResult(cached_result) - # Create a wrapper function that records completion - def wrapped_func(*a: Any, **kw: Any) -> Any: - exception = None - result = None - try: - result = func(*a, **kw) - except BaseException as e: - exception = e - - # Note: This runs in a thread pool, so we need to be careful - # The actual recording will happen when the result is awaited - if exception: - raise _DurableExecutionException( - func, args, kwargs, result, exception, self._record_call_completion - ) - return _DurableExecutionResult( - func, args, kwargs, result, self._record_call_completion - ) - - return _DurableAsyncExecutionResult(self.executor, wrapped_func, args, kwargs) + return _DurableAsyncExecutionResult( + self.executor, + self._wrap_completion_only_func(func, args, kwargs), + args, + kwargs, + ) @property @override diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 75a73efb2..53823b61d 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -20,7 +20,7 @@ import uuid from collections import deque from concurrent.futures import Future -from typing import Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List from typing_extensions import override @@ -31,12 +31,14 @@ from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext -from flink_agents.plan.agent_plan import AgentPlan from flink_agents.plan.configuration import AgentConfiguration from flink_agents.runtime.agent_runner import AgentRunner from flink_agents.runtime.local_memory_object import LocalMemoryObject from flink_agents.runtime.resource_cache import ResourceCache +if TYPE_CHECKING: + from flink_agents.plan.agent_plan import AgentPlan + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) @@ -58,7 +60,7 @@ class LocalRunnerContext(RunnerContext): Name of the action being executed. """ - __agent_plan: AgentPlan | None + __agent_plan: Any __key: Any events: deque[Event] action_name: str @@ -69,7 +71,7 @@ class LocalRunnerContext(RunnerContext): _config: AgentConfiguration def __init__( - self, agent_plan: AgentPlan, key: Any, config: AgentConfiguration + self, agent_plan: "AgentPlan", key: Any, config: AgentConfiguration ) -> None: """Initialize a new context with the given agent and key. @@ -190,6 +192,7 @@ def durable_execute( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> Any: """Synchronously execute the provided function. Access to memory @@ -208,6 +211,7 @@ def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, + reconciler: Callable[[], Any] | None = None, **kwargs: Any, ) -> AsyncExecutionResult: """Asynchronously execute the provided function. Access to memory @@ -271,7 +275,7 @@ class LocalRunner(AgentRunner): Internal configration. """ - __agent_plan: AgentPlan + __agent_plan: Any __keyed_contexts: Dict[Any, LocalRunnerContext] __outputs: List[Dict[str, Any]] __config: AgentConfiguration @@ -284,6 +288,8 @@ def __init__(self, agent: Agent, config: AgentConfiguration) -> None: agent : Agent The agent class to convert and run. """ + from flink_agents.plan.agent_plan import AgentPlan + self.__agent_plan = AgentPlan.from_agent(agent, config) self.__keyed_contexts = {} self.__outputs = [] diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index 080ad846f..52d777a2c 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -18,10 +18,12 @@ """Tests for durable execution helper functions.""" import cloudpickle +import pytest -from flink_agents.runtime.flink_runner_context import ( +from flink_agents.runtime.durable_execution import ( _compute_args_digest, _compute_function_id, + _validate_reconciler_callable, ) @@ -48,6 +50,22 @@ def class_method(cls, x: int) -> int: return x * 4 +class ReconcilerCallables: + """Helpers for reconciler callable validation tests.""" + + def __init__(self, prefix: str) -> None: + """Store a prefix used by the helper callables.""" + self.prefix = prefix + + def bound_no_arg(self) -> str: + """Return a bound zero-argument reconciler result.""" + return f"bound:{self.prefix}" + + def requires_arg(self, value: int) -> str: + """Require an argument so validation can reject the callable.""" + return f"{self.prefix}:{value}" + + def test_compute_function_id_for_function() -> None: """Test function ID computation for regular functions.""" func_id = _compute_function_id(sample_function) @@ -127,6 +145,49 @@ def test_compute_args_digest_kwargs_vs_args() -> None: assert digest1 != digest2 +def test_validate_reconciler_callable_accepts_none() -> None: + """Allow omitting the reconciler callable.""" + assert _validate_reconciler_callable(None) is None + + +def test_validate_reconciler_callable_accepts_zero_arg_function() -> None: + """Accept a zero-argument reconciler function.""" + def reconciler() -> str: + return "ok" + + validated = _validate_reconciler_callable(reconciler) + + assert validated is reconciler + assert validated() == "ok" + + +def test_validate_reconciler_callable_accepts_bound_zero_arg_method() -> None: + """Accept a bound reconciler method with no remaining arguments.""" + callables = ReconcilerCallables("client") + bound_method = callables.bound_no_arg + + validated = _validate_reconciler_callable(bound_method) + + assert validated is bound_method + assert validated() == "bound:client" + + +def test_validate_reconciler_callable_requires_callable() -> None: + """Reject non-callable reconciler values.""" + with pytest.raises(TypeError, match="reconciler must be callable"): + _validate_reconciler_callable(1) # type: ignore[arg-type] + + +def test_validate_reconciler_callable_requires_zero_args() -> None: + """Reject reconciler callables that require arguments.""" + callables = ReconcilerCallables("client") + + with pytest.raises( + TypeError, match="reconciler must be a callable that takes no arguments" + ): + _validate_reconciler_callable(callables.requires_arg) + + def test_cloudpickle_serialization() -> None: """Test that results can be serialized and deserialized with cloudpickle.""" # Test basic types @@ -216,4 +277,3 @@ def test_cloudpickle_none_exception_message() -> None: assert isinstance(deserialized, RuntimeError) # str() of an exception with None message is "None" assert str(deserialized) == "None" - diff --git a/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py new file mode 100644 index 000000000..4d64eae41 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_flink_runner_context_reconcilable.py @@ -0,0 +1,410 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import asyncio +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Callable + +import cloudpickle +import pytest + +from flink_agents.runtime.durable_execution import ( + _compute_args_digest, + _compute_function_id, +) +from flink_agents.runtime.flink_runner_context import FlinkRunnerContext + + +@dataclass +class _StoredCallResult: + function_id: str + args_digest: str + status: str + result_payload: bytes | None = None + exception_payload: bytes | None = None + + +class _FakeJavaRunnerContext: + def __init__(self) -> None: + self.call_results: list[_StoredCallResult] = [] + self.current_call_index = 0 + self.operations: list[str] = [] + + def getCurrentCallResultFields(self) -> list[Any] | None: + self.operations.append("peek") + if self.current_call_index < len(self.call_results): + current = self.call_results[self.current_call_index] + return [ + current.function_id, + current.args_digest, + current.status, + current.result_payload, + current.exception_payload, + ] + return None + + def matchNextOrClearSubsequentCallResult( + self, function_id: str, args_digest: str + ) -> list[Any] | None: + self.operations.append("match") + if self.current_call_index < len(self.call_results): + current = self.call_results[self.current_call_index] + if ( + current.function_id == function_id + and current.args_digest == args_digest + ): + self.current_call_index += 1 + return [True, current.result_payload, current.exception_payload] + self.call_results = self.call_results[: self.current_call_index] + return None + + def recordCallCompletion( + self, + function_id: str, + args_digest: str, + result_payload: bytes | None, + exception_payload: bytes | None, + ) -> None: + self.operations.append("record") + status = "FAILED" if exception_payload is not None else "SUCCEEDED" + self.call_results.append( + _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status=status, + result_payload=result_payload, + exception_payload=exception_payload, + ) + ) + self.current_call_index += 1 + + def appendPendingCall(self, function_id: str, args_digest: str) -> None: + self.operations.append("append_pending") + self.call_results.append( + _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status="PENDING", + ) + ) + + def finalizeCurrentCall( + self, + function_id: str, + args_digest: str, + result_payload: bytes | None, + exception_payload: bytes | None, + ) -> None: + self.operations.append("finalize") + current = self.call_results[self.current_call_index] + assert current.status == "PENDING" + assert current.function_id == function_id + assert current.args_digest == args_digest + self.call_results[self.current_call_index] = _StoredCallResult( + function_id=function_id, + args_digest=args_digest, + status="FAILED" if exception_payload is not None else "SUCCEEDED", + result_payload=result_payload, + exception_payload=exception_payload, + ) + self.current_call_index += 1 + + def clearCallResultsFromCurrentIndexAndPersist(self) -> None: + self.operations.append("clear") + self.call_results = self.call_results[: self.current_call_index] + + +def _create_runner_context( + j_runner_context: _FakeJavaRunnerContext, +) -> FlinkRunnerContext: + ctx = FlinkRunnerContext.__new__(FlinkRunnerContext) + ctx._j_runner_context = j_runner_context + ctx.executor = ThreadPoolExecutor(max_workers=1) + ctx._FlinkRunnerContext__agent_plan = None + ctx._FlinkRunnerContext__ltm = None + return ctx + + +def _close_runner_context(ctx: FlinkRunnerContext) -> None: + ctx.executor.shutdown(wait=True) + + +def _run_async(result: Any) -> object: + async def _await_result() -> Any: + return await result + + return asyncio.run(_await_result()) + + +def _preload_pending( + j_runner_context: _FakeJavaRunnerContext, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> None: + j_runner_context.call_results.append( + _StoredCallResult( + function_id=_compute_function_id(func), + args_digest=_compute_args_digest(args, kwargs), + status="PENDING", + ) + ) + + +def _call_value(value: str) -> str: + return f"call:{value}" + + +def test_flink_runner_context_sync_with_reconciler_executes_original_call() -> None: + """Start a new durable call when no pending state exists.""" + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + reconciler_called = False + + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True + return "reconciled:order-1" + + try: + result = ctx.durable_execute(_call_value, "order-1", reconciler=reconciler) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconciler_called is False + assert j_runner_context.operations == ["peek", "append_pending", "finalize"] + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_sync_reconciler_success() -> None: + """Persist a recovered success without re-executing the original call.""" + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + try: + result = ctx.durable_execute( + tracked_call, + "order-1", + reconciler=lambda: "reconciled:order-1", + ) + finally: + _close_runner_context(ctx) + + assert result == "reconciled:order-1" + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert cloudpickle.loads(j_runner_context.call_results[0].result_payload) == ( + "reconciled:order-1" + ) + + +def test_flink_runner_context_sync_reconciler_exception_persists_failure() -> None: + """Persist a recovered failure from the reconciler and re-raise it.""" + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + def reconciler() -> str: + error_message = "failed:order-1" + raise ValueError(error_message) + + try: + with pytest.raises(ValueError, match="failed:order-1"): + ctx.durable_execute(tracked_call, "order-1", reconciler=reconciler) + finally: + _close_runner_context(ctx) + + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "FAILED" + persisted_exception = cloudpickle.loads( + j_runner_context.call_results[0].exception_payload + ) + assert isinstance(persisted_exception, ValueError) + assert str(persisted_exception) == "failed:order-1" + assert j_runner_context.current_call_index == 1 + + +def test_flink_runner_context_sync_reconciler_mismatch_clears_and_executes() -> None: + """Clear mismatched persisted state before executing the original call.""" + j_runner_context = _FakeJavaRunnerContext() + stale_result_payload = cloudpickle.dumps("stale") + j_runner_context.call_results.extend( + [ + _StoredCallResult( + function_id=_compute_function_id(_call_value), + args_digest=_compute_args_digest(("other-order",), {}), + status="PENDING", + ), + _StoredCallResult( + function_id="stale.function", + args_digest="stale-args", + status="SUCCEEDED", + result_payload=stale_result_payload, + ), + ] + ) + ctx = _create_runner_context(j_runner_context) + reconciler_called = False + + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True + return "reconciled:order-1" + + try: + result = ctx.durable_execute(_call_value, "order-1", reconciler=reconciler) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconciler_called is False + assert j_runner_context.operations == ["peek", "clear", "append_pending", "finalize"] + assert len(j_runner_context.call_results) == 1 + assert j_runner_context.call_results[0].function_id == _compute_function_id(_call_value) + assert j_runner_context.call_results[0].args_digest == _compute_args_digest( + ("order-1",), {} + ) + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_async_writes_pending_on_await() -> None: + """Defer pending-state writes for async execution until await time.""" + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + reconciler_called = False + + def reconciler() -> str: + nonlocal reconciler_called + reconciler_called = True + return "reconciled:order-1" + + try: + async_result = ctx.durable_execute_async( + _call_value, + "order-1", + reconciler=reconciler, + ) + assert j_runner_context.call_results == [] + result = _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert result == "call:order-1" + assert reconciler_called is False + assert j_runner_context.operations == ["peek", "append_pending", "finalize"] + assert j_runner_context.call_results[0].status == "SUCCEEDED" + + +def test_flink_runner_context_async_reconciler_success() -> None: + """Recover a successful async result through the reconciler.""" + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + try: + async_result = ctx.durable_execute_async( + tracked_call, + "order-1", + reconciler=lambda: "reconciled:order-1", + ) + result = _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert result == "reconciled:order-1" + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + + +def test_flink_runner_context_async_reconciler_exception_persists_failure() -> None: + """Persist an async reconciler failure and re-raise it.""" + j_runner_context = _FakeJavaRunnerContext() + call_count = 0 + + def tracked_call(value: str) -> str: + nonlocal call_count + call_count += 1 + return _call_value(value) + + _preload_pending(j_runner_context, tracked_call, "order-1") + ctx = _create_runner_context(j_runner_context) + + def reconciler() -> str: + error_message = "reconcile unavailable" + raise RuntimeError(error_message) + + try: + async_result = ctx.durable_execute_async( + tracked_call, + "order-1", + reconciler=reconciler, + ) + with pytest.raises(RuntimeError, match="reconcile unavailable"): + _run_async(async_result) + finally: + _close_runner_context(ctx) + + assert call_count == 0 + assert j_runner_context.operations == ["peek", "finalize"] + assert j_runner_context.call_results[0].status == "FAILED" + persisted_exception = cloudpickle.loads( + j_runner_context.call_results[0].exception_payload + ) + assert isinstance(persisted_exception, RuntimeError) + assert str(persisted_exception) == "reconcile unavailable" + assert j_runner_context.current_call_index == 1 + + +def test_flink_runner_context_reconciler_kwarg_is_not_forwarded() -> None: + """Keep the reserved reconciler kwarg out of the user function call.""" + j_runner_context = _FakeJavaRunnerContext() + ctx = _create_runner_context(j_runner_context) + + def collect_kwargs(**kwargs: Any) -> dict[str, Any]: + return kwargs + + try: + result = ctx.durable_execute(collect_kwargs, reconciler=lambda: "unused") + finally: + _close_runner_context(ctx) + + assert result == {} diff --git a/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py new file mode 100644 index 000000000..a213423f2 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_local_runner_reconcilable.py @@ -0,0 +1,79 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import asyncio +from typing import Any + +from flink_agents.runtime.local_runner import LocalRunnerContext + + +def reconciled_add(x: int, y: int) -> int: + """Return a simple deterministic value for local-runner tests.""" + return x + y + + +def _create_local_runner_context() -> LocalRunnerContext: + return LocalRunnerContext.__new__(LocalRunnerContext) + + +def test_local_runner_context_reconciler_durable_execute_degrades() -> None: + """Keep sync local execution on the existing non-durable path.""" + ctx = _create_local_runner_context() + reconciler_called = False + + def reconciler() -> int: + nonlocal reconciler_called + reconciler_called = True + return 999 + + result = ctx.durable_execute(reconciled_add, 5, 10, reconciler=reconciler) + + assert result == 15 + assert reconciler_called is False + + +def test_local_runner_context_reconciler_durable_execute_async_degrades() -> None: + """Keep async local execution on the existing non-durable path.""" + ctx = _create_local_runner_context() + reconciler_called = False + + def reconciler() -> int: + nonlocal reconciler_called + reconciler_called = True + return 999 + + async_result = ctx.durable_execute_async( + reconciled_add, 5, 10, reconciler=reconciler + ) + + async def _await_result() -> Any: + return await async_result + + assert asyncio.run(_await_result()) == 15 + assert reconciler_called is False + + +def test_local_runner_context_reconciler_kwarg_is_not_forwarded() -> None: + """Do not forward the reserved reconciler kwarg to the user function.""" + ctx = _create_local_runner_context() + + def collect_kwargs(**kwargs: Any) -> dict[str, Any]: + return kwargs + + result = ctx.durable_execute(collect_kwargs, reconciler=lambda: "unused") + + assert result == {} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index f4d28d3b0..3e302854a 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -441,6 +441,24 @@ public void clearCallResultsFromCurrentIndexAndPersist() { } } + /** + * Returns the current durable call result as an array of fields for bridge consumers, or null + * if no persisted slot exists at the current call index. + */ + public Object[] getCurrentCallResultFields() { + CallResult current = getCurrentCallResult(); + if (current == null) { + return null; + } + return new Object[] { + current.getFunctionId(), + current.getArgsDigest(), + current.isPending() ? "PENDING" : current.isFailure() ? "FAILED" : "SUCCEEDED", + current.getResultPayload(), + current.getExceptionPayload() + }; + } + protected CallResult getCurrentCallResult() { mailboxThreadChecker.run(); if (durableExecutionContext != null) {