From 0bf39821ee6bbd348571d50e9a7475ac72ae2c46 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sun, 5 Apr 2026 22:31:46 -0700 Subject: [PATCH 01/14] [runtime] Extract OperatorStateManager from ActionExecutionOperator Move 7 Flink state fields, 2 constants, and state management methods into a package-private OperatorStateManager class. The operator delegates all state access through the manager. Moved fields: actionTasksKState, pendingInputEventsKState, currentProcessingKeysOpState, sequenceNumberKState, sensoryMemState, shortTermMemState, jobIdentifier. Part of #545. --- .../operator/ActionExecutionOperator.java | 195 +++------------ .../operator/OperatorStateManager.java | 228 ++++++++++++++++++ 2 files changed, 266 insertions(+), 157 deletions(-) create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index e5015a3a5..75fd33a2e 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -49,7 +49,6 @@ import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.eventlog.FileEventLogger; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; -import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; @@ -64,18 +63,11 @@ import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; -import 
org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; @@ -104,10 +96,8 @@ import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; -import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; /** @@ -128,8 +118,6 @@ public class ActionExecutionOperator extends AbstractStreamOperator extends AbstractStreamOperator reusedStreamRecord; - private transient MapState sensoryMemState; - - private transient MapState shortTermMemState; - private transient PythonEnvironmentManager pythonEnvironmentManager; private transient PythonInterpreter pythonInterpreter; @@ -177,24 +161,12 @@ public class ActionExecutionOperator extends AbstractStreamOperator actionTasksKState; - - // To 
avoid processing different InputEvents with the same key, we use a state to store pending - // InputEvents that are waiting to be processed. - private transient ListState pendingInputEventsKState; - - // An operator state is used to track the currently processing keys. This is useful when - // receiving an EndOfInput signal, as we need to wait until all related events are fully - // processed. - private transient ListState currentProcessingKeysOpState; + private final transient OperatorStateManager stateManager; private final transient EventLogger eventLogger; private final transient List eventListeners; private transient ActionStateStore actionStateStore; - private transient ValueState sequenceNumberKState; private transient ListState recoveryMarkerOpState; private transient Map> checkpointIdToSeqNums; @@ -212,13 +184,6 @@ public class ActionExecutionOperator extends AbstractStreamOperator pythonAwaitableRefs; - // Each job can only have one identifier and this identifier must be consistent across restarts. - // We cannot use job id as the identifier here because user may change job id by - // creating a savepoint, stop the job and then resume from savepoint. - // We use this identifier to control the visibility for long-term memory. - // Inspired by Apache Paimon. 
- private transient String jobIdentifier; - private transient ContinuationActionExecutor continuationActionExecutor; public ActionExecutionOperator( @@ -231,6 +196,7 @@ public ActionExecutionOperator( this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; this.mailboxExecutor = mailboxExecutor; + this.stateManager = new OperatorStateManager(); this.eventLogger = createEventLogger(agentPlan); this.eventListeners = new ArrayList<>(); this.actionStateStore = actionStateStore; @@ -254,21 +220,9 @@ public void setup( public void open() throws Exception { super.open(); reusedStreamRecord = new StreamRecord<>(null); - // init sensoryMemState - MapStateDescriptor sensoryMemStateDescriptor = - new MapStateDescriptor<>( - "sensoryMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); - - // init shortTermMemState - MapStateDescriptor shortTermMemStateDescriptor = - new MapStateDescriptor<>( - "shortTermMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - shortTermMemState = getRuntimeContext().getMapState(shortTermMemStateDescriptor); + + stateManager.initializeKeyedStates(getRuntimeContext()); + stateManager.initializeOperatorStates(getOperatorStateBackend()); resourceCache = new ResourceCache(agentPlan.getResourceProviders()); @@ -288,34 +242,6 @@ public void open() throws Exception { RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); } - // init sequence number state for per key message ordering - sequenceNumberKState = - getRuntimeContext() - .getState( - new ValueStateDescriptor<>( - MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); - - // init agent processing related state - actionTasksKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - "actionTasks", TypeInformation.of(ActionTask.class))); - pendingInputEventsKState = - 
getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, - TypeInformation.of(Event.class))); - // We use UnionList here to ensure that the task can access all keys after parallelism - // modifications. - // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do - // not belong to the key range of current task. - currentProcessingKeysOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - "currentProcessingKeys", TypeInformation.of(Object.class))); // init PythonActionExecutor and PythonResourceAdapter initPythonEnvironment(); @@ -363,11 +289,11 @@ public void processElement(StreamRecord record) throws Exception { keySegmentQueue.addKeyToLastSegment(getCurrentKey()); - if (currentKeyHasMoreActionTask()) { + if (stateManager.hasMoreActionTasks()) { // If there are already actions being processed for the current key, the newly incoming // event should be queued and processed later. Therefore, we add it to // pendingInputEventsState. - pendingInputEventsKState.add(inputEvent); + stateManager.addPendingInputEvent(inputEvent); } else { // Otherwise, the new event is processed immediately. processEvent(getCurrentKey(), inputEvent); @@ -394,15 +320,15 @@ private void processEvent(Object key, Event event) throws Exception { } else { if (isInputEvent) { // If the event is an InputEvent, we mark that the key is currently being processed. - currentProcessingKeysOpState.add(key); - initOrIncSequenceNumber(); + stateManager.addProcessingKey(key); + stateManager.initOrIncSequenceNumber(); } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. 
List triggerActions = getActionsTriggeredBy(event); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { - actionTasksKState.add(createActionTask(key, triggerAction, event)); + stateManager.addActionTask(createActionTask(key, triggerAction, event)); } } } @@ -449,9 +375,9 @@ private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. setCurrentKey(key); - ActionTask actionTask = pollFromListState(actionTasksKState); + ActionTask actionTask = stateManager.pollNextActionTask(); if (actionTask == null) { - int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); checkState( removedCount == 1, "Current processing key count for key " @@ -468,7 +394,7 @@ private void processActionTaskForKey(Object key) throws Exception { // 2. Invoke the action task. createAndSetRunnerContext(actionTask, key); - long sequenceNumber = sequenceNumberKState.value(); + long sequenceNumber = stateManager.getSequenceNumber(); boolean isFinished; List outputEvents; Optional generatedActionTaskOpt = Optional.empty(); @@ -540,7 +466,7 @@ private void processActionTaskForKey(Object key) throws Exception { boolean currentInputEventFinished = false; if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); - currentInputEventFinished = !currentKeyHasMoreActionTask(); + currentInputEventFinished = !stateManager.hasMoreActionTasks(); // Persist memory to the Flink state when the action task is finished. actionTask.getRunnerContext().persistMemory(); @@ -577,7 +503,7 @@ private void processActionTaskForKey(Object key) throws Exception { } } - actionTasksKState.add(generatedActionTask); + stateManager.addActionTask(generatedActionTask); } // 3. 
Process the next InputEvent or next action task @@ -587,7 +513,7 @@ private void processActionTaskForKey(Object key) throws Exception { // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. - int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); maybePruneState(key, sequenceNumber); checkState( removedCount == 1, @@ -599,11 +525,11 @@ private void processActionTaskForKey(Object key) throws Exception { keySegmentQueue.removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); processEligibleWatermarks(); - Event pendingInputEvent = pollFromListState(pendingInputEventsKState); + Event pendingInputEvent = stateManager.pollNextPendingInputEvent(); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); } - } else if (currentKeyHasMoreActionTask()) { + } else if (stateManager.hasMoreActionTasks()) { // If the current key has additional action tasks remaining, we should submit a new mail // to continue processing them. 
mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); @@ -658,7 +584,7 @@ private void initPythonEnvironment() throws Exception { this::checkMailboxThread, this.agentPlan, this.resourceCache, - this.jobIdentifier); + stateManager.getJobIdentifier()); javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); if (containPythonResource) { @@ -677,7 +603,7 @@ private void initPythonActionExecutor() throws Exception { agentPlan, javaResourceAdapter, pythonRunnerContext, - jobIdentifier); + stateManager.getJobIdentifier()); pythonActionExecutor.open(); } @@ -705,7 +631,7 @@ public void endInput() throws Exception { @VisibleForTesting public void waitInFlightEventsFinished() throws Exception { - while (listStateNotEmpty(currentProcessingKeysOpState)) { + while (stateManager.hasProcessingKeys()) { mailboxExecutor.yield(); } } @@ -773,15 +699,7 @@ public void initializeState(StateInitializationContext context) throws Exception actionStateStore.rebuildState(markers); } - // Get job identifier from user configuration. - // If not configured, get from state. 
- jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); - if (jobIdentifier == null) { - String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); - jobIdentifier = - StateUtils.getSingleValueFromState( - context, "identifier_state", String.class, initialJobIdentifier); - } + stateManager.initJobIdentifier(context, agentPlan, getRuntimeContext()); } @Override @@ -793,14 +711,8 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } } - HashMap keyToSeqNum = new HashMap<>(); - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), - (key, state) -> keyToSeqNum.put(key, state.value())); - checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum); + stateManager.snapshotSequenceNumbers( + getKeyedStateBackend(), checkpointIdToSeqNums, context.getCheckpointId()); super.snapshotState(context); } @@ -893,8 +805,8 @@ private void createAndSetRunnerContext(ActionTask actionTask, Object key) { } else { memoryContext = new RunnerContextImpl.MemoryContext( - new CachedMemoryStore(sensoryMemState), - new CachedMemoryStore(shortTermMemState)); + new CachedMemoryStore(stateManager.getSensoryMemState()), + new CachedMemoryStore(stateManager.getShortTermMemState())); } runnerContext.switchActionContext( @@ -920,18 +832,16 @@ private void createAndSetRunnerContext(ActionTask actionTask, Object key) { actionTask.setRunnerContext(runnerContext); } - private boolean currentKeyHasMoreActionTask() throws Exception { - return listStateNotEmpty(actionTasksKState); - } - private void tryResumeProcessActionTasks() throws Exception { - Iterable keys = currentProcessingKeysOpState.get(); + Iterable keys = stateManager.getProcessingKeys(); if (keys != null) { int maxParallelism = getRuntimeContext().getTaskInfo().getMaxNumberOfParallelSubtasks(); KeyGroupRange currentSubtaskKeyGroupRange = - 
getCurrentSubtaskKeyGroupRange(maxParallelism); + stateManager.getCurrentSubtaskKeyGroupRange( + maxParallelism, getRuntimeContext()); for (Object key : keys) { - if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { + if (!stateManager.isKeyOwnedByCurrentSubtask( + key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } keySegmentQueue.addKeyToLastSegment(key); @@ -940,26 +850,10 @@ private void tryResumeProcessActionTasks() throws Exception { } } - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), - (key, state) -> - state.get() - .forEach( - event -> keySegmentQueue.addKeyToLastSegment(key))); - } - - private void initOrIncSequenceNumber() throws Exception { - // Initialize the sequence number state if it does not exist. - Long sequenceNumber = sequenceNumberKState.value(); - if (sequenceNumber == null) { - sequenceNumberKState.update(0L); - } else { - sequenceNumberKState.update(sequenceNumber + 1); - } + stateManager.forEachPendingInputEventKey( + getKeyedStateBackend(), + (key, state) -> + state.get().forEach(event -> keySegmentQueue.addKeyToLastSegment(key))); } private ActionState maybeGetActionState( @@ -1047,7 +941,7 @@ private void setupDurableExecutionContext(ActionTask actionTask, ActionState act // Create new context for first invocation final long sequenceNumber; try { - sequenceNumber = sequenceNumberKState.value(); + sequenceNumber = stateManager.getSequenceNumber(); } catch (Exception e) { throw new RuntimeException("Failed to get sequence number from state", e); } @@ -1106,7 +1000,7 @@ private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { this::checkMailboxThread, this.agentPlan, this.resourceCache, - this.jobIdentifier, + stateManager.getJobIdentifier(), continuationActionExecutor); } return runnerContext; @@ -1118,7 +1012,7 @@ private 
RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { this::checkMailboxThread, this.agentPlan, this.resourceCache, - jobIdentifier); + stateManager.getJobIdentifier()); } return pythonRunnerContext; } @@ -1144,19 +1038,6 @@ private void maybeInitActionStateStore() { } } - private KeyGroupRange getCurrentSubtaskKeyGroupRange(int maxParallelism) { - int parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); - int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); - return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( - maxParallelism, parallelism, subtaskIndex); - } - - private boolean isKeyOwnedByCurrentSubtask( - Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) { - int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); - return currentSubtaskKeyGroupRange.contains(keyGroup); - } - /** Failed to execute Action task. */ public static class ActionTaskExecutionException extends Exception { public ActionTaskExecutionException(String message, Throwable cause) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java new file mode 100644 index 000000000..0a20670f3 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; +import static org.apache.flink.agents.runtime.utils.StateUtil.*; + +class OperatorStateManager { + + static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; + static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents"; + + private ListState actionTasksKState; + private ListState 
pendingInputEventsKState; + private ListState currentProcessingKeysOpState; + private ValueState sequenceNumberKState; + private MapState sensoryMemState; + private MapState shortTermMemState; + private String jobIdentifier; + + OperatorStateManager() {} + + void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext runtimeContext) + throws Exception { + // init sensoryMemState + MapStateDescriptor sensoryMemStateDescriptor = + new MapStateDescriptor<>( + "sensoryMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + sensoryMemState = runtimeContext.getMapState(sensoryMemStateDescriptor); + + // init shortTermMemState + MapStateDescriptor shortTermMemStateDescriptor = + new MapStateDescriptor<>( + "shortTermMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + shortTermMemState = runtimeContext.getMapState(shortTermMemStateDescriptor); + + // init sequence number state for per key message ordering + sequenceNumberKState = + runtimeContext.getState( + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); + + // init agent processing related state + actionTasksKState = + runtimeContext.getListState( + new ListStateDescriptor<>( + "actionTasks", TypeInformation.of(ActionTask.class))); + pendingInputEventsKState = + runtimeContext.getListState( + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class))); + } + + void initializeOperatorStates(OperatorStateBackend operatorStateBackend) throws Exception { + // We use UnionList here to ensure that the task can access all keys after parallelism + // modifications. + // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do + // not belong to the key range of current task. 
+ currentProcessingKeysOpState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + "currentProcessingKeys", TypeInformation.of(Object.class))); + } + + void initJobIdentifier( + StateInitializationContext context, + AgentPlan agentPlan, + org.apache.flink.api.common.functions.RuntimeContext runtimeContext) + throws Exception { + // Get job identifier from user configuration. + // If not configured, get from state. + jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); + if (jobIdentifier == null) { + String initialJobIdentifier = runtimeContext.getJobInfo().getJobId().toString(); + jobIdentifier = + StateUtils.getSingleValueFromState( + context, "identifier_state", String.class, initialJobIdentifier); + } + } + + void initOrIncSequenceNumber() throws Exception { + // Initialize the sequence number state if it does not exist. + Long sequenceNumber = sequenceNumberKState.value(); + if (sequenceNumber == null) { + sequenceNumberKState.update(0L); + } else { + sequenceNumberKState.update(sequenceNumber + 1); + } + } + + long getSequenceNumber() throws Exception { + return sequenceNumberKState.value(); + } + + boolean hasMoreActionTasks() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + + ActionTask pollNextActionTask() throws Exception { + return pollFromListState(actionTasksKState); + } + + void addActionTask(ActionTask actionTask) throws Exception { + actionTasksKState.add(actionTask); + } + + void addPendingInputEvent(Event event) throws Exception { + pendingInputEventsKState.add(event); + } + + Event pollNextPendingInputEvent() throws Exception { + return pollFromListState(pendingInputEventsKState); + } + + void addProcessingKey(Object key) throws Exception { + currentProcessingKeysOpState.add(key); + } + + int removeProcessingKey(Object key) throws Exception { + return removeFromListState(currentProcessingKeysOpState, key); + } + + boolean hasProcessingKeys() throws Exception { + return 
listStateNotEmpty(currentProcessingKeysOpState); + } + + Iterable getProcessingKeys() throws Exception { + return currentProcessingKeysOpState.get(); + } + + MapState getSensoryMemState() { + return sensoryMemState; + } + + MapState getShortTermMemState() { + return shortTermMemState; + } + + String getJobIdentifier() { + return jobIdentifier; + } + + KeyGroupRange getCurrentSubtaskKeyGroupRange( + int maxParallelism, + org.apache.flink.api.common.functions.RuntimeContext runtimeContext) { + int parallelism = runtimeContext.getTaskInfo().getNumberOfParallelSubtasks(); + int subtaskIndex = runtimeContext.getTaskInfo().getIndexOfThisSubtask(); + return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( + maxParallelism, parallelism, subtaskIndex); + } + + boolean isKeyOwnedByCurrentSubtask( + Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) { + int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); + return currentSubtaskKeyGroupRange.contains(keyGroup); + } + + @SuppressWarnings("unchecked") + void snapshotSequenceNumbers( + KeyedStateBackend keyedStateBackend, + Map> checkpointIdToSeqNums, + long checkpointId) + throws Exception { + HashMap keyToSeqNum = new HashMap<>(); + ((KeyedStateBackend) keyedStateBackend) + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), + (key, state) -> keyToSeqNum.put(key, state.value())); + checkpointIdToSeqNums.put(checkpointId, keyToSeqNum); + } + + @SuppressWarnings("unchecked") + void forEachPendingInputEventKey( + KeyedStateBackend keyedStateBackend, + KeyedStateFunction> function) + throws Exception { + ((KeyedStateBackend) keyedStateBackend) + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), + function); + } +} From 
771ec58b3c6d6969d99215b3c44c87ad383b3d7b Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sun, 5 Apr 2026 22:52:04 -0700 Subject: [PATCH 02/14] [runtime] Extract DurableExecutionManager from ActionExecutionOperator Move action state persistence, recovery markers, checkpoint maps, and durable/continuation context maps into a package-private DurableExecutionManager class. The manager implements ActionStatePersister (moved from the operator). Moved fields: actionStateStore, recoveryMarkerOpState, checkpointIdToSeqNums, actionTaskDurableContexts, continuationContexts, pythonAwaitableRefs. Test reflection accesses to actionStateStore updated to use @VisibleForTesting getter chain. Part of #545. --- .../operator/ActionExecutionOperator.java | 431 ++++++++---------- .../operator/DurableExecutionManager.java | 285 ++++++++++++ .../operator/ActionExecutionOperatorTest.java | 36 +- 3 files changed, 480 insertions(+), 272 deletions(-) create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 75fd33a2e..87a59a68f 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -39,16 +39,15 @@ import org.apache.flink.agents.runtime.ResourceCache; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; -import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; import org.apache.flink.agents.runtime.async.ContinuationContext; -import org.apache.flink.agents.runtime.context.ActionStatePersister; import 
org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.eventlog.FileEventLogger; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; @@ -63,11 +62,18 @@ import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; @@ -94,10 +100,10 @@ import java.util.Map; import java.util.Optional; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; -import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; +import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; /** @@ -111,13 +117,14 @@ * and the resulting output event is collected for further processing. */ public class ActionExecutionOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput, ActionStatePersister { + implements OneInputStreamOperator, BoundedOneInput { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(ActionExecutionOperator.class); - private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; + private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; + private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents"; private final AgentPlan agentPlan; @@ -127,6 +134,10 @@ public class ActionExecutionOperator extends AbstractStreamOperator reusedStreamRecord; + private transient MapState sensoryMemState; + + private transient MapState shortTermMemState; + private transient PythonEnvironmentManager pythonEnvironmentManager; private transient PythonInterpreter pythonInterpreter; @@ -161,28 +172,37 @@ public class ActionExecutionOperator extends AbstractStreamOperator actionTasksKState; + + // To avoid processing different InputEvents with the same key, we use a state to store pending + // InputEvents that are waiting to be processed. + private transient ListState pendingInputEventsKState; + + // An operator state is used to track the currently processing keys. 
This is useful when + // receiving an EndOfInput signal, as we need to wait until all related events are fully + // processed. + private transient ListState currentProcessingKeysOpState; private final transient EventLogger eventLogger; private final transient List eventListeners; - private transient ActionStateStore actionStateStore; - private transient ListState recoveryMarkerOpState; - private transient Map> checkpointIdToSeqNums; + private transient ValueState sequenceNumberKState; + + private final transient DurableExecutionManager durableExecManager; // This in memory map keep track of the runner context for the async action task that having // been finished private final transient Map actionTaskMemoryContexts; - // This in memory map keeps track of the durable execution context for async action tasks - // that have not been finished, allowing recovery of currentCallIndex across invocations - private final transient Map - actionTaskDurableContexts; - - private final transient Map continuationContexts; - - private final transient Map pythonAwaitableRefs; + // Each job can only have one identifier and this identifier must be consistent across restarts. + // We cannot use job id as the identifier here because user may change job id by + // creating a savepoint, stop the job and then resume from savepoint. + // We use this identifier to control the visibility for long-term memory. + // Inspired by Apache Paimon. 
+ private transient String jobIdentifier; private transient ContinuationActionExecutor continuationActionExecutor; @@ -196,15 +216,10 @@ public ActionExecutionOperator( this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; this.mailboxExecutor = mailboxExecutor; - this.stateManager = new OperatorStateManager(); this.eventLogger = createEventLogger(agentPlan); this.eventListeners = new ArrayList<>(); - this.actionStateStore = actionStateStore; - this.checkpointIdToSeqNums = new HashMap<>(); + this.durableExecManager = new DurableExecutionManager(actionStateStore); this.actionTaskMemoryContexts = new HashMap<>(); - this.actionTaskDurableContexts = new HashMap<>(); - this.continuationContexts = new HashMap<>(); - this.pythonAwaitableRefs = new HashMap<>(); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -220,9 +235,21 @@ public void setup( public void open() throws Exception { super.open(); reusedStreamRecord = new StreamRecord<>(null); - - stateManager.initializeKeyedStates(getRuntimeContext()); - stateManager.initializeOperatorStates(getOperatorStateBackend()); + // init sensoryMemState + MapStateDescriptor sensoryMemStateDescriptor = + new MapStateDescriptor<>( + "sensoryMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); + + // init shortTermMemState + MapStateDescriptor shortTermMemStateDescriptor = + new MapStateDescriptor<>( + "shortTermMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + shortTermMemState = getRuntimeContext().getMapState(shortTermMemStateDescriptor); resourceCache = new ResourceCache(agentPlan.getResourceProviders()); @@ -231,17 +258,36 @@ public void open() throws Exception { keySegmentQueue = new SegmentedQueue(); - maybeInitActionStateStore(); - - if (actionStateStore != null) { - // init recovery marker state 
for recovery marker persistence - recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - } + durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); + durableExecManager.initRecoveryMarkerState(getOperatorStateBackend()); + // init sequence number state for per key message ordering + sequenceNumberKState = + getRuntimeContext() + .getState( + new ValueStateDescriptor<>( + MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); + + // init agent processing related state + actionTasksKState = + getRuntimeContext() + .getListState( + new ListStateDescriptor<>( + "actionTasks", TypeInformation.of(ActionTask.class))); + pendingInputEventsKState = + getRuntimeContext() + .getListState( + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, + TypeInformation.of(Event.class))); + // We use UnionList here to ensure that the task can access all keys after parallelism + // modifications. + // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do + // not belong to the key range of current task. + currentProcessingKeysOpState = + getOperatorStateBackend() + .getUnionListState( + new ListStateDescriptor<>( + "currentProcessingKeys", TypeInformation.of(Object.class))); // init PythonActionExecutor and PythonResourceAdapter initPythonEnvironment(); @@ -289,11 +335,11 @@ public void processElement(StreamRecord record) throws Exception { keySegmentQueue.addKeyToLastSegment(getCurrentKey()); - if (stateManager.hasMoreActionTasks()) { + if (currentKeyHasMoreActionTask()) { // If there are already actions being processed for the current key, the newly incoming // event should be queued and processed later. Therefore, we add it to // pendingInputEventsState. - stateManager.addPendingInputEvent(inputEvent); + pendingInputEventsKState.add(inputEvent); } else { // Otherwise, the new event is processed immediately. 
processEvent(getCurrentKey(), inputEvent); @@ -320,15 +366,15 @@ private void processEvent(Object key, Event event) throws Exception { } else { if (isInputEvent) { // If the event is an InputEvent, we mark that the key is currently being processed. - stateManager.addProcessingKey(key); - stateManager.initOrIncSequenceNumber(); + currentProcessingKeysOpState.add(key); + initOrIncSequenceNumber(); } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. List triggerActions = getActionsTriggeredBy(event); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { - stateManager.addActionTask(createActionTask(key, triggerAction, event)); + actionTasksKState.add(createActionTask(key, triggerAction, event)); } } } @@ -375,9 +421,9 @@ private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. setCurrentKey(key); - ActionTask actionTask = stateManager.pollNextActionTask(); + ActionTask actionTask = pollFromListState(actionTasksKState); if (actionTask == null) { - int removedCount = stateManager.removeProcessingKey(key); + int removedCount = removeFromListState(currentProcessingKeysOpState, key); checkState( removedCount == 1, "Current processing key count for key " @@ -394,12 +440,13 @@ private void processActionTaskForKey(Object key) throws Exception { // 2. Invoke the action task. 
createAndSetRunnerContext(actionTask, key); - long sequenceNumber = stateManager.getSequenceNumber(); + long sequenceNumber = sequenceNumberKState.value(); boolean isFinished; List outputEvents; Optional generatedActionTaskOpt = Optional.empty(); ActionState actionState = - maybeGetActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecManager.maybeGetActionState( + key, sequenceNumber, actionTask.action, actionTask.event); // Check if action is already completed if (actionState != null && actionState.isCompleted()) { @@ -426,14 +473,16 @@ private void processActionTaskForKey(Object key) throws Exception { } else { // Initialize ActionState if not exists, or use existing one for recovery if (actionState == null) { - maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecManager.maybeInitActionState( + key, sequenceNumber, actionTask.action, actionTask.event); actionState = - maybeGetActionState( + durableExecManager.maybeGetActionState( key, sequenceNumber, actionTask.action, actionTask.event); } // Set up durable execution context for fine-grained recovery - setupDurableExecutionContext(actionTask, actionState); + durableExecManager.setupDurableExecutionContext( + actionTask, actionState, sequenceNumber); ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke( @@ -444,10 +493,10 @@ private void processActionTaskForKey(Object key) throws Exception { // back later if the action task has a generated action task, meaning it is not // finished. 
actionTaskMemoryContexts.remove(actionTask); - actionTaskDurableContexts.remove(actionTask); - continuationContexts.remove(actionTask); - pythonAwaitableRefs.remove(actionTask); - maybePersistTaskResult( + durableExecManager.removeDurableContext(actionTask); + durableExecManager.removeContinuationContext(actionTask); + durableExecManager.removePythonAwaitableRef(actionTask); + durableExecManager.maybePersistTaskResult( key, sequenceNumber, actionTask.action, @@ -466,7 +515,7 @@ private void processActionTaskForKey(Object key) throws Exception { boolean currentInputEventFinished = false; if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); - currentInputEventFinished = !stateManager.hasMoreActionTasks(); + currentInputEventFinished = !currentKeyHasMoreActionTask(); // Persist memory to the Flink state when the action task is finished. actionTask.getRunnerContext().persistMemory(); @@ -486,10 +535,10 @@ private void processActionTaskForKey(Object key) throws Exception { RunnerContextImpl.DurableExecutionContext durableContext = actionTask.getRunnerContext().getDurableExecutionContext(); if (durableContext != null) { - actionTaskDurableContexts.put(generatedActionTask, durableContext); + durableExecManager.putDurableContext(generatedActionTask, durableContext); } if (actionTask.getRunnerContext() instanceof JavaRunnerContextImpl) { - continuationContexts.put( + durableExecManager.putContinuationContext( generatedActionTask, ((JavaRunnerContextImpl) actionTask.getRunnerContext()) .getContinuationContext()); @@ -499,11 +548,11 @@ private void processActionTaskForKey(Object key) throws Exception { ((PythonRunnerContextImpl) actionTask.getRunnerContext()) .getPythonAwaitableRef(); if (awaitableRef != null) { - pythonAwaitableRefs.put(generatedActionTask, awaitableRef); + durableExecManager.putPythonAwaitableRef(generatedActionTask, awaitableRef); } } - stateManager.addActionTask(generatedActionTask); + 
actionTasksKState.add(generatedActionTask); } // 3. Process the next InputEvent or next action task @@ -513,8 +562,8 @@ private void processActionTaskForKey(Object key) throws Exception { // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. - int removedCount = stateManager.removeProcessingKey(key); - maybePruneState(key, sequenceNumber); + int removedCount = removeFromListState(currentProcessingKeysOpState, key); + durableExecManager.maybePruneState(key, sequenceNumber); checkState( removedCount == 1, "Current processing key count for key " @@ -525,11 +574,11 @@ private void processActionTaskForKey(Object key) throws Exception { keySegmentQueue.removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); processEligibleWatermarks(); - Event pendingInputEvent = stateManager.pollNextPendingInputEvent(); + Event pendingInputEvent = pollFromListState(pendingInputEventsKState); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); } - } else if (stateManager.hasMoreActionTasks()) { + } else if (currentKeyHasMoreActionTask()) { // If the current key has additional action tasks remaining, we should submit a new mail // to continue processing them. 
mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); @@ -584,7 +633,7 @@ private void initPythonEnvironment() throws Exception { this::checkMailboxThread, this.agentPlan, this.resourceCache, - stateManager.getJobIdentifier()); + this.jobIdentifier); javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); if (containPythonResource) { @@ -603,7 +652,7 @@ private void initPythonActionExecutor() throws Exception { agentPlan, javaResourceAdapter, pythonRunnerContext, - stateManager.getJobIdentifier()); + jobIdentifier); pythonActionExecutor.open(); } @@ -631,7 +680,7 @@ public void endInput() throws Exception { @VisibleForTesting public void waitInFlightEventsFinished() throws Exception { - while (stateManager.hasProcessingKeys()) { + while (listStateNotEmpty(currentProcessingKeysOpState)) { mailboxExecutor.yield(); } } @@ -661,8 +710,8 @@ public void close() throws Exception { if (eventLogger != null) { eventLogger.close(); } - if (actionStateStore != null) { - actionStateStore.close(); + if (durableExecManager != null) { + durableExecManager.close(); } if (continuationActionExecutor != null) { continuationActionExecutor.close(); @@ -675,58 +724,39 @@ public void close() throws Exception { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - maybeInitActionStateStore(); - - if (actionStateStore != null) { - List markers = new ArrayList<>(); - - // We use UnionList here to ensure that the task can access all the recovery marker - // after - // parallelism modifications. - // The ActionStateStore will decide how to use the recovery markers. 
- ListState recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - - Iterable recoveryMarkers = recoveryMarkerOpState.get(); - if (recoveryMarkers != null) { - recoveryMarkers.forEach(markers::add); - } - LOG.info("Rebuilding action state from {} recovery markers", markers.size()); - actionStateStore.rebuildState(markers); - } + durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); + durableExecManager.handleRecovery(getOperatorStateBackend()); - stateManager.initJobIdentifier(context, agentPlan, getRuntimeContext()); + // Get job identifier from user configuration. + // If not configured, get from state. + jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); + if (jobIdentifier == null) { + String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); + jobIdentifier = + StateUtils.getSingleValueFromState( + context, "identifier_state", String.class, initialJobIdentifier); + } } @Override public void snapshotState(StateSnapshotContext context) throws Exception { - if (actionStateStore != null) { - Object recoveryMarker = actionStateStore.getRecoveryMarker(); - if (recoveryMarker != null) { - recoveryMarkerOpState.update(List.of(recoveryMarker)); - } - } + durableExecManager.snapshotRecoveryMarker(); - stateManager.snapshotSequenceNumbers( - getKeyedStateBackend(), checkpointIdToSeqNums, context.getCheckpointId()); + HashMap keyToSeqNum = new HashMap<>(); + getKeyedStateBackend() + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), + (key, state) -> keyToSeqNum.put(key, state.value())); + durableExecManager.recordCheckpointSequenceNumbers(context.getCheckpointId(), keyToSeqNum); super.snapshotState(context); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { - if 
(actionStateStore != null) { - Map keyToSeqNum = - checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); - for (Map.Entry entry : keyToSeqNum.entrySet()) { - actionStateStore.pruneState(entry.getKey(), entry.getValue()); - } - checkpointIdToSeqNums.remove(checkpointId); - } + durableExecManager.notifyCheckpointComplete(checkpointId); super.notifyCheckpointComplete(checkpointId); } @@ -805,8 +835,8 @@ private void createAndSetRunnerContext(ActionTask actionTask, Object key) { } else { memoryContext = new RunnerContextImpl.MemoryContext( - new CachedMemoryStore(stateManager.getSensoryMemState()), - new CachedMemoryStore(stateManager.getShortTermMemState())); + new CachedMemoryStore(sensoryMemState), + new CachedMemoryStore(shortTermMemState)); } runnerContext.switchActionContext( @@ -814,10 +844,10 @@ private void createAndSetRunnerContext(ActionTask actionTask, Object key) { if (runnerContext instanceof JavaRunnerContextImpl) { ContinuationContext continuationContext; - if (continuationContexts.containsKey(actionTask)) { + if (durableExecManager.hasContinuationContext(actionTask)) { // action task for async execution action, should retrieve intermediate results from // map. - continuationContext = continuationContexts.get(actionTask); + continuationContext = durableExecManager.getContinuationContext(actionTask); } else { continuationContext = new ContinuationContext(); } @@ -826,22 +856,24 @@ private void createAndSetRunnerContext(ActionTask actionTask, Object key) { if (runnerContext instanceof PythonRunnerContextImpl) { // Get the awaitable ref from the transient map. After checkpoint restore, this will be // null, signaling that the awaitable was lost and needs re-execution. 
- String awaitableRef = pythonAwaitableRefs.get(actionTask); + String awaitableRef = durableExecManager.getPythonAwaitableRef(actionTask); ((PythonRunnerContextImpl) runnerContext).setPythonAwaitableRef(awaitableRef); } actionTask.setRunnerContext(runnerContext); } + private boolean currentKeyHasMoreActionTask() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + private void tryResumeProcessActionTasks() throws Exception { - Iterable keys = stateManager.getProcessingKeys(); + Iterable keys = currentProcessingKeysOpState.get(); if (keys != null) { int maxParallelism = getRuntimeContext().getTaskInfo().getMaxNumberOfParallelSubtasks(); KeyGroupRange currentSubtaskKeyGroupRange = - stateManager.getCurrentSubtaskKeyGroupRange( - maxParallelism, getRuntimeContext()); + getCurrentSubtaskKeyGroupRange(maxParallelism); for (Object key : keys) { - if (!stateManager.isKeyOwnedByCurrentSubtask( - key, maxParallelism, currentSubtaskKeyGroupRange)) { + if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } keySegmentQueue.addKeyToLastSegment(key); @@ -850,129 +882,25 @@ private void tryResumeProcessActionTasks() throws Exception { } } - stateManager.forEachPendingInputEventKey( - getKeyedStateBackend(), - (key, state) -> - state.get().forEach(event -> keySegmentQueue.addKeyToLastSegment(key))); - } - - private ActionState maybeGetActionState( - Object key, long sequenceNum, Action action, Event event) throws Exception { - return actionStateStore == null - ? null - : actionStateStore.get(key.toString(), sequenceNum, action, event); - } - - private void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) - throws Exception { - if (actionStateStore != null) { - // Initialize the action state if it does not exist. It will exist when the action is an - // async action and - // has been persisted before the action task is finished. 
- if (actionStateStore.get(key, sequenceNum, action, event) == null) { - actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); - } - } - } - - private void maybePersistTaskResult( - Object key, - long sequenceNum, - Action action, - Event event, - RunnerContextImpl context, - ActionTask.ActionTaskResult actionTaskResult) - throws Exception { - if (actionStateStore == null) { - return; - } - - // if the task is not finished, we skip the persistence for now and wait until it is - // finished. - if (!actionTaskResult.isFinished()) { - return; - } - - ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); - - for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { - actionState.addSensoryMemoryUpdate(memoryUpdate); - } - - for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { - actionState.addShortTermMemoryUpdate(memoryUpdate); - } - - for (Event outputEvent : actionTaskResult.getOutputEvents()) { - actionState.addEvent(outputEvent); - } - - // Mark the action as completed and clear call records - // This indicates that recovery should skip the entire action - actionState.markCompleted(); - - actionStateStore.put(key, sequenceNum, action, event, actionState); - - // Clear durable execution context - context.clearDurableExecutionContext(); + getKeyedStateBackend() + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), + (key, state) -> + state.get() + .forEach( + event -> keySegmentQueue.addKeyToLastSegment(key))); } - /** - * Sets up the durable execution context for fine-grained recovery. - * - *

This method initializes the runner context with a {@link - * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to: - * - *

    - *
  • Skip re-execution for already completed calls during recovery - *
  • Persist CallRecords after each code block completion - *
- */ - private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) { - if (actionStateStore == null) { - return; - } - - RunnerContextImpl.DurableExecutionContext durableContext; - if (actionTaskDurableContexts.containsKey(actionTask)) { - // Reuse existing context for async action continuation - durableContext = actionTaskDurableContexts.get(actionTask); + private void initOrIncSequenceNumber() throws Exception { + // Initialize the sequence number state if it does not exist. + Long sequenceNumber = sequenceNumberKState.value(); + if (sequenceNumber == null) { + sequenceNumberKState.update(0L); } else { - // Create new context for first invocation - final long sequenceNumber; - try { - sequenceNumber = stateManager.getSequenceNumber(); - } catch (Exception e) { - throw new RuntimeException("Failed to get sequence number from state", e); - } - - durableContext = - new RunnerContextImpl.DurableExecutionContext( - actionTask.getKey(), - sequenceNumber, - actionTask.action, - actionTask.event, - actionState, - this); - } - - actionTask.getRunnerContext().setDurableExecutionContext(durableContext); - } - - @Override - public void persist( - Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { - try { - actionStateStore.put(key, sequenceNumber, action, event, actionState); - } catch (Exception e) { - LOG.error("Failed to persist ActionState", e); - throw new RuntimeException("Failed to persist ActionState", e); - } - } - - private void maybePruneState(Object key, long sequenceNum) throws Exception { - if (actionStateStore != null) { - actionStateStore.pruneState(key, sequenceNum); + sequenceNumberKState.update(sequenceNumber + 1); } } @@ -1000,7 +928,7 @@ private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { this::checkMailboxThread, this.agentPlan, this.resourceCache, - stateManager.getJobIdentifier(), + this.jobIdentifier, continuationActionExecutor); } return runnerContext; @@ -1012,7 
+940,7 @@ private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { this::checkMailboxThread, this.agentPlan, this.resourceCache, - stateManager.getJobIdentifier()); + jobIdentifier); } return pythonRunnerContext; } @@ -1029,13 +957,22 @@ private EventLogger createEventLogger(AgentPlan agentPlan) { return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); } - private void maybeInitActionStateStore() { - if (actionStateStore == null - && KAFKA.getType() - .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) { - LOG.info("Using Kafka as backend of action state store."); - actionStateStore = new KafkaActionStateStore(agentPlan.getConfig()); - } + @VisibleForTesting + DurableExecutionManager getDurableExecutionManager() { + return durableExecManager; + } + + private KeyGroupRange getCurrentSubtaskKeyGroupRange(int maxParallelism) { + int parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); + int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); + return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( + maxParallelism, parallelism, subtaskIndex); + } + + private boolean isKeyOwnedByCurrentSubtask( + Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) { + int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); + return currentSubtaskKeyGroupRange.contains(keyGroup); } /** Failed to execute Action task. */ diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java new file mode 100644 index 000000000..0be24052b --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.context.MemoryUpdate; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.ActionStateStore; +import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.apache.flink.agents.runtime.context.ActionStatePersister; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; +import static 
org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; + +class DurableExecutionManager implements ActionStatePersister, AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(DurableExecutionManager.class); + + private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; + + private ActionStateStore actionStateStore; + private transient ListState recoveryMarkerOpState; + private final Map> checkpointIdToSeqNums; + + private final Map + actionTaskDurableContexts; + private final Map continuationContexts; + private final Map pythonAwaitableRefs; + + DurableExecutionManager(@Nullable ActionStateStore actionStateStore) { + this.actionStateStore = actionStateStore; + this.checkpointIdToSeqNums = new HashMap<>(); + this.actionTaskDurableContexts = new HashMap<>(); + this.continuationContexts = new HashMap<>(); + this.pythonAwaitableRefs = new HashMap<>(); + } + + void maybeInitActionStateStore(AgentConfiguration config) { + if (actionStateStore == null + && KAFKA.getType().equalsIgnoreCase(config.get(ACTION_STATE_STORE_BACKEND))) { + LOG.info("Using Kafka as backend of action state store."); + actionStateStore = new KafkaActionStateStore(config); + } + } + + boolean hasDurableStore() { + return actionStateStore != null; + } + + void initRecoveryMarkerState(OperatorStateBackend operatorStateBackend) throws Exception { + if (actionStateStore != null) { + recoveryMarkerOpState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + } + } + + // Note: Re-creates the union list state descriptor here because handleRecovery() is called + // from initializeState() which runs BEFORE open(), so recoveryMarkerOpState is not yet + // initialized. The descriptor name matches exactly, so Flink returns the same state. 
+ void handleRecovery(OperatorStateBackend operatorStateBackend) throws Exception { + if (actionStateStore != null) { + List markers = new ArrayList<>(); + ListState markerState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + Iterable recoveryMarkers = markerState.get(); + if (recoveryMarkers != null) { + recoveryMarkers.forEach(markers::add); + } + LOG.info("Rebuilding action state from {} recovery markers", markers.size()); + actionStateStore.rebuildState(markers); + } + } + + ActionState maybeGetActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + return actionStateStore == null + ? null + : actionStateStore.get(key.toString(), sequenceNum, action, event); + } + + void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + if (actionStateStore != null) { + if (actionStateStore.get(key, sequenceNum, action, event) == null) { + actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); + } + } + } + + void maybePersistTaskResult( + Object key, + long sequenceNum, + Action action, + Event event, + RunnerContextImpl context, + ActionTask.ActionTaskResult actionTaskResult) + throws Exception { + if (actionStateStore == null) { + return; + } + + if (!actionTaskResult.isFinished()) { + return; + } + + ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); + + for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { + actionState.addSensoryMemoryUpdate(memoryUpdate); + } + + for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { + actionState.addShortTermMemoryUpdate(memoryUpdate); + } + + for (Event outputEvent : actionTaskResult.getOutputEvents()) { + actionState.addEvent(outputEvent); + } + + actionState.markCompleted(); + + actionStateStore.put(key, sequenceNum, action, event, actionState); + + 
context.clearDurableExecutionContext(); + } + + void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState, long seqNum) { + if (actionStateStore == null) { + return; + } + + RunnerContextImpl.DurableExecutionContext durableContext; + if (actionTaskDurableContexts.containsKey(actionTask)) { + durableContext = actionTaskDurableContexts.get(actionTask); + } else { + durableContext = + new RunnerContextImpl.DurableExecutionContext( + actionTask.getKey(), + seqNum, + actionTask.action, + actionTask.event, + actionState, + this); + } + + actionTask.getRunnerContext().setDurableExecutionContext(durableContext); + } + + @Override + public void persist( + Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { + try { + actionStateStore.put(key, sequenceNumber, action, event, actionState); + } catch (Exception e) { + LOG.error("Failed to persist ActionState", e); + throw new RuntimeException("Failed to persist ActionState", e); + } + } + + void maybePruneState(Object key, long sequenceNum) throws Exception { + if (actionStateStore != null) { + actionStateStore.pruneState(key, sequenceNum); + } + } + + void notifyCheckpointComplete(long checkpointId) { + if (actionStateStore != null) { + Map keyToSeqNum = + checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); + for (Map.Entry entry : keyToSeqNum.entrySet()) { + actionStateStore.pruneState(entry.getKey(), entry.getValue()); + } + checkpointIdToSeqNums.remove(checkpointId); + } + } + + void snapshotRecoveryMarker() throws Exception { + if (actionStateStore != null) { + Object recoveryMarker = actionStateStore.getRecoveryMarker(); + if (recoveryMarker != null) { + recoveryMarkerOpState.update(List.of(recoveryMarker)); + } + } + } + + void recordCheckpointSequenceNumbers(long checkpointId, Map seqNums) { + checkpointIdToSeqNums.put(checkpointId, seqNums); + } + + // --- Context map accessors --- + + RunnerContextImpl.DurableExecutionContext 
getDurableContext(ActionTask actionTask) { + return actionTaskDurableContexts.get(actionTask); + } + + void putDurableContext( + ActionTask actionTask, RunnerContextImpl.DurableExecutionContext context) { + actionTaskDurableContexts.put(actionTask, context); + } + + void removeDurableContext(ActionTask actionTask) { + actionTaskDurableContexts.remove(actionTask); + } + + ContinuationContext getContinuationContext(ActionTask actionTask) { + return continuationContexts.get(actionTask); + } + + void putContinuationContext(ActionTask actionTask, ContinuationContext context) { + continuationContexts.put(actionTask, context); + } + + void removeContinuationContext(ActionTask actionTask) { + continuationContexts.remove(actionTask); + } + + String getPythonAwaitableRef(ActionTask actionTask) { + return pythonAwaitableRefs.get(actionTask); + } + + void putPythonAwaitableRef(ActionTask actionTask, String ref) { + pythonAwaitableRefs.put(actionTask, ref); + } + + void removePythonAwaitableRef(ActionTask actionTask) { + pythonAwaitableRefs.remove(actionTask); + } + + boolean hasDurableContext(ActionTask actionTask) { + return actionTaskDurableContexts.containsKey(actionTask); + } + + boolean hasContinuationContext(ActionTask actionTask) { + return continuationContexts.containsKey(actionTask); + } + + @VisibleForTesting + ActionStateStore getActionStateStore() { + return actionStateStore; + } + + @Override + public void close() throws Exception { + if (actionStateStore != null) { + actionStateStore.close(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index a45a52d4a..b5e9fd641 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -252,12 +252,9 @@ 
agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore).isNotNull(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); @@ -346,12 +343,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 3L; testHarness.processElement(new StreamRecord<>(inputValue)); @@ -421,12 +415,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Access the action state store - java.lang.reflect.Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); // Process multiple elements with same key to test state 
persistence testHarness.processElement(new StreamRecord<>(1L)); @@ -490,12 +481,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(true)), (List>) testHarness.getRecordOutput(); assertThat(recordOutput.size()).isEqualTo(3); - // Access the action state store - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); } } @@ -514,11 +502,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Access the action state store - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); - actionStateStore = (InMemoryActionStateStore) actionStateStoreField.get(operator); + actionStateStore = + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 7L; From fda9e45b1e60e56a78b7bd208f845d28ab652918 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sun, 5 Apr 2026 23:28:01 -0700 Subject: [PATCH 03/14] [runtime] Extract ActionTaskContextManager, EventRouter, and PythonBridgeManager Extract the remaining 3 manager classes from ActionExecutionOperator: - ActionTaskContextManager: runner context creation, memory contexts, continuation executor - EventRouter: event wrapping/routing, notification, watermarks, logging - PythonBridgeManager: Python env, interpreter, executor, resource adapters Test reflection for eventLogger updated to use @VisibleForTesting getter. Part of #545. 
--- .../operator/ActionExecutionOperator.java | 428 +++--------------- .../operator/ActionTaskContextManager.java | 189 ++++++++ .../agents/runtime/operator/EventRouter.java | 175 +++++++ .../runtime/operator/PythonBridgeManager.java | 177 ++++++++ .../operator/ActionExecutionOperatorTest.java | 4 +- 5 files changed, 614 insertions(+), 359 deletions(-) create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java create mode 100644 runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 87a59a68f..da09c6eee 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -18,45 +18,23 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.EventContext; -import org.apache.flink.agents.api.InputEvent; import org.apache.flink.agents.api.OutputEvent; import org.apache.flink.agents.api.agents.AgentExecutionOptions; import org.apache.flink.agents.api.context.MemoryUpdate; -import org.apache.flink.agents.api.listener.EventListener; -import org.apache.flink.agents.api.logger.EventLogger; -import org.apache.flink.agents.api.logger.EventLoggerConfig; -import org.apache.flink.agents.api.logger.EventLoggerFactory; -import org.apache.flink.agents.api.logger.EventLoggerOpenParams; -import org.apache.flink.agents.api.resource.Resource; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; 
import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.plan.actions.Action; -import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; -import org.apache.flink.agents.runtime.PythonMCPResourceDiscovery; import org.apache.flink.agents.runtime.ResourceCache; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; -import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; -import org.apache.flink.agents.runtime.async.ContinuationContext; import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; import org.apache.flink.agents.runtime.context.RunnerContextImpl; -import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; -import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; -import org.apache.flink.agents.runtime.eventlog.FileEventLogger; -import org.apache.flink.agents.runtime.memory.CachedMemoryStore; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; -import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; -import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.python.operator.PythonActionTask; -import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; -import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; -import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; import org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; @@ -67,7 +45,6 @@ import org.apache.flink.api.common.state.ValueState; import 
org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.StateInitializationContext; @@ -80,29 +57,22 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor; -import org.apache.flink.types.Row; import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pemja.core.PythonInterpreter; import java.lang.reflect.Field; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; @@ -132,38 +102,19 @@ public class ActionExecutionOperator extends AbstractStreamOperator reusedStreamRecord; - private transient MapState sensoryMemState; private transient MapState shortTermMemState; - 
private transient PythonEnvironmentManager pythonEnvironmentManager; - - private transient PythonInterpreter pythonInterpreter; - - // PythonActionExecutor for Python actions - private transient PythonActionExecutor pythonActionExecutor; - - // RunnerContext for Python actions - private transient PythonRunnerContextImpl pythonRunnerContext; - - // PythonResourceAdapter for Python resources in Java actions - private transient PythonResourceAdapterImpl pythonResourceAdapter; - - // PythonResourceAdapter for Java resources in Python actions or Python resources - private transient JavaResourceAdapter javaResourceAdapter; + private transient PythonBridgeManager pythonBridge; private transient FlinkAgentsMetricGroupImpl metricGroup; private transient BuiltInMetrics builtInMetrics; - private transient SegmentedQueue keySegmentQueue; - private final transient MailboxExecutor mailboxExecutor; - // RunnerContext for Java Actions - private transient RunnerContextImpl runnerContext; + private transient ActionTaskContextManager contextManager; // We need to check whether the current thread is the mailbox thread using the mailbox // processor. @@ -185,18 +136,12 @@ public class ActionExecutionOperator extends AbstractStreamOperator currentProcessingKeysOpState; - private final transient EventLogger eventLogger; - private final transient List eventListeners; + private final transient EventRouter eventRouter; private transient ValueState sequenceNumberKState; private final transient DurableExecutionManager durableExecManager; - // This in memory map keep track of the runner context for the async action task that having - // been finished - private final transient Map - actionTaskMemoryContexts; - // Each job can only have one identifier and this identifier must be consistent across restarts. // We cannot use job id as the identifier here because user may change job id by // creating a savepoint, stop the job and then resume from savepoint. 
@@ -204,8 +149,6 @@ public class ActionExecutionOperator extends AbstractStreamOperator(); + this.eventRouter = new EventRouter<>(agentPlan, inputIsJava); this.durableExecManager = new DurableExecutionManager(actionStateStore); - this.actionTaskMemoryContexts = new HashMap<>(); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -234,7 +175,6 @@ public void setup( @Override public void open() throws Exception { super.open(); - reusedStreamRecord = new StreamRecord<>(null); // init sensoryMemState MapStateDescriptor sensoryMemStateDescriptor = new MapStateDescriptor<>( @@ -256,7 +196,7 @@ public void open() throws Exception { metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); - keySegmentQueue = new SegmentedQueue(); + eventRouter.open(builtInMetrics); durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); durableExecManager.initRecoveryMarkerState(getOperatorStateBackend()); @@ -290,17 +230,27 @@ public void open() throws Exception { "currentProcessingKeys", TypeInformation.of(Object.class))); // init PythonActionExecutor and PythonResourceAdapter - initPythonEnvironment(); - - // init executor for Java async execution - continuationActionExecutor = - new ContinuationActionExecutor( + pythonBridge = new PythonBridgeManager(); + pythonBridge.open( + agentPlan, + resourceCache, + getExecutionConfig(), + getRuntimeContext().getDistributedCache(), + getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(), + getRuntimeContext().getJobInfo().getJobId(), + metricGroup, + this::checkMailboxThread, + jobIdentifier); + + // init context manager for runner context creation and memory contexts + contextManager = + new ActionTaskContextManager( agentPlan.getConfig().get(AgentExecutionOptions.NUM_ASYNC_THREADS)); mailboxProcessor = getMailboxProcessor(); // Initialize the event logger if it is set. 
- initEventLogger(getRuntimeContext()); + eventRouter.initEventLogger(getRuntimeContext()); // Since an operator restart may change the key range it manages due to changes in // parallelism, @@ -309,17 +259,10 @@ public void open() throws Exception { tryResumeProcessActionTasks(); } - private void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { - if (eventLogger == null) { - return; - } - eventLogger.open(new EventLoggerOpenParams(runtimeContext)); - } - @Override public void processWatermark(Watermark mark) throws Exception { - keySegmentQueue.addWatermark(mark); - processEligibleWatermarks(); + eventRouter.getKeySegmentQueue().addWatermark(mark); + eventRouter.processEligibleWatermarks(super::processWatermark); } @Override @@ -328,12 +271,13 @@ public void processElement(StreamRecord record) throws Exception { LOG.debug("Receive an element {}", input); // wrap to InputEvent first - Event inputEvent = wrapToInputEvent(input); + Event inputEvent = + eventRouter.wrapToInputEvent(input, pythonBridge.getPythonActionExecutor()); if (record.hasTimestamp()) { inputEvent.setSourceTimestamp(record.getTimestamp()); } - keySegmentQueue.addKeyToLastSegment(getCurrentKey()); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(getCurrentKey()); if (currentKeyHasMoreActionTask()) { // If there are already actions being processed for the current key, the newly incoming @@ -351,17 +295,22 @@ public void processElement(StreamRecord record) throws Exception { * `tryProcessActionTaskForKey` to continue processing. */ private void processEvent(Object key, Event event) throws Exception { - notifyEventProcessed(event); + eventRouter.notifyEventProcessed(event); boolean isInputEvent = EventUtil.isInputEvent(event); if (EventUtil.isOutputEvent(event)) { // If the event is an OutputEvent, we send it downstream. 
- OUT outputData = getOutputFromOutputEvent(event); + OUT outputData = + eventRouter.getOutputFromOutputEvent( + event, pythonBridge.getPythonActionExecutor()); if (event.hasSourceTimestamp()) { - output.collect(reusedStreamRecord.replace(outputData, event.getSourceTimestamp())); + output.collect( + eventRouter + .getReusedStreamRecord() + .replace(outputData, event.getSourceTimestamp())); } else { - reusedStreamRecord.eraseTimestamp(); - output.collect(reusedStreamRecord.replace(outputData)); + eventRouter.getReusedStreamRecord().eraseTimestamp(); + output.collect(eventRouter.getReusedStreamRecord().replace(outputData)); } } else { if (isInputEvent) { @@ -371,7 +320,7 @@ private void processEvent(Object key, Event event) throws Exception { } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. - List triggerActions = getActionsTriggeredBy(event); + List triggerActions = eventRouter.getActionsTriggeredBy(event, agentPlan); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { actionTasksKState.add(createActionTask(key, triggerAction, event)); @@ -385,25 +334,6 @@ private void processEvent(Object key, Event event) throws Exception { } } - private void notifyEventProcessed(Event event) throws Exception { - EventContext eventContext = new EventContext(event); - if (eventLogger != null) { - // If event logging is enabled, we log the event along with its context. - eventLogger.append(eventContext, event); - // For now, we flush the event logger after each event to ensure immediate logging. - // This is a temporary solution to ensure that events are logged immediately. - // TODO: In the future, we may want to implement a more efficient batching mechanism. - eventLogger.flush(); - } - if (eventListeners != null) { - // Notify all registered event listeners about the event. 
- for (EventListener listener : eventListeners) { - listener.onEventProcessed(eventContext, event); - } - } - builtInMetrics.markEventProcessed(); - } - private void tryProcessActionTaskForKey(Object key) { try { processActionTaskForKey(key); @@ -431,14 +361,25 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); + eventRouter.processEligibleWatermarks(super::processWatermark); return; } // 2. Invoke the action task. - createAndSetRunnerContext(actionTask, key); + contextManager.createAndSetRunnerContext( + actionTask, + key, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + this::checkMailboxThread, + sensoryMemState, + shortTermMemState, + pythonBridge.getPythonRunnerContext(), + durableExecManager); long sequenceNumber = sequenceNumberKState.value(); boolean isFinished; @@ -487,12 +428,12 @@ private void processActionTaskForKey(Object key) throws Exception { ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke( getRuntimeContext().getUserCodeClassLoader(), - this.pythonActionExecutor); + this.pythonBridge.getPythonActionExecutor()); // We remove the contexts from the map after the task is processed. They will be added // back later if the action task has a generated action task, meaning it is not // finished. - actionTaskMemoryContexts.remove(actionTask); + contextManager.removeMemoryContext(actionTask); durableExecManager.removeDurableContext(actionTask); durableExecManager.removeContinuationContext(actionTask); durableExecManager.removePythonAwaitableRef(actionTask); @@ -530,7 +471,7 @@ private void processActionTaskForKey(Object key) throws Exception { // If the action task is not finished, we keep the contexts in memory for the // next generated ActionTask to be invoked. 
- actionTaskMemoryContexts.put( + contextManager.putMemoryContext( generatedActionTask, actionTask.getRunnerContext().getMemoryContext()); RunnerContextImpl.DurableExecutionContext durableContext = actionTask.getRunnerContext().getDurableExecutionContext(); @@ -571,9 +512,9 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); + eventRouter.processEligibleWatermarks(super::processWatermark); Event pendingInputEvent = pollFromListState(pendingInputEventsKState); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); @@ -585,94 +526,6 @@ private void processActionTaskForKey(Object key) throws Exception { } } - private Resource getResource(String name, ResourceType type) { - try { - return resourceCache.getResource(name, type); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private void initPythonEnvironment() throws Exception { - boolean containPythonAction = - agentPlan.getActions().values().stream() - .anyMatch(action -> action.getExec() instanceof PythonFunction); - - boolean containPythonResource = - agentPlan.getResourceProviders().values().stream() - .anyMatch( - resourceProviderMap -> - resourceProviderMap.values().stream() - .anyMatch( - resourceProvider -> - resourceProvider - instanceof - PythonResourceProvider)); - - if (containPythonAction || containPythonResource) { - LOG.debug("Begin initialize PythonEnvironmentManager."); - PythonDependencyInfo dependencyInfo = - PythonDependencyInfo.create( - getExecutionConfig().toConfiguration(), - getRuntimeContext().getDistributedCache()); - pythonEnvironmentManager = - new PythonEnvironmentManager( - dependencyInfo, - getContainingTask() - .getEnvironment() - .getTaskManagerInfo() - .getTmpDirectories(), - new 
HashMap<>(System.getenv()), - getRuntimeContext().getJobInfo().getJobId()); - pythonEnvironmentManager.open(); - EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); - pythonInterpreter = env.getInterpreter(); - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - this.jobIdentifier); - - javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); - if (containPythonResource) { - initPythonResourceAdapter(); - } - if (containPythonAction) { - initPythonActionExecutor(); - } - } - } - - private void initPythonActionExecutor() throws Exception { - pythonActionExecutor = - new PythonActionExecutor( - pythonInterpreter, - agentPlan, - javaResourceAdapter, - pythonRunnerContext, - jobIdentifier); - pythonActionExecutor.open(); - } - - private void initPythonResourceAdapter() throws Exception { - pythonResourceAdapter = - new PythonResourceAdapterImpl( - (String anotherName, ResourceType anotherType) -> { - try { - return resourceCache.getResource(anotherName, anotherType); - } catch (Exception e) { - throw new RuntimeException(e); - } - }, - pythonInterpreter, - javaResourceAdapter); - pythonResourceAdapter.open(); - PythonMCPResourceDiscovery.discoverPythonMCPResources( - agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); - } - @Override public void endInput() throws Exception { waitInFlightEventsFinished(); @@ -691,31 +544,18 @@ public void close() throws Exception { if (resourceCache != null) { resourceCache.close(); } - if (runnerContext != null) { - try { - runnerContext.close(); - } finally { - runnerContext = null; - } - } - if (pythonActionExecutor != null) { - pythonActionExecutor.close(); + if (contextManager != null) { + contextManager.close(); } - if (pythonInterpreter != null) { - pythonInterpreter.close(); + if (pythonBridge != null) { + pythonBridge.close(); } - if (pythonEnvironmentManager != 
null) { - pythonEnvironmentManager.close(); - } - if (eventLogger != null) { - eventLogger.close(); + if (eventRouter != null) { + eventRouter.close(); } if (durableExecManager != null) { durableExecManager.close(); } - if (continuationActionExecutor != null) { - continuationActionExecutor.close(); - } super.close(); } @@ -760,39 +600,6 @@ public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); } - private Event wrapToInputEvent(IN input) { - if (inputIsJava) { - return new InputEvent(input); - } else { - // the input data must originate from Python and be of type Row with two fields — the - // first representing the key, and the second representing the actual data payload. - checkState(input instanceof Row && ((Row) input).getArity() == 2); - return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); - } - } - - private OUT getOutputFromOutputEvent(Event event) { - checkState(EventUtil.isOutputEvent(event)); - if (event instanceof OutputEvent) { - return (OUT) ((OutputEvent) event).getOutput(); - } else if (event instanceof PythonEvent) { - Object outputFromOutputEvent = - pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); - return (OUT) outputFromOutputEvent; - } else { - throw new IllegalStateException( - "Unsupported event type: " + event.getClass().getName()); - } - } - - private List getActionsTriggeredBy(Event event) { - if (event instanceof PythonEvent) { - return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); - } else { - return agentPlan.getActionsTriggeredBy(event.getClass().getName()); - } - } - private MailboxProcessor getMailboxProcessor() throws Exception { Field field = MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor"); field.setAccessible(true); @@ -816,52 +623,6 @@ private ActionTask createActionTask(Object key, Action action, Event event) { } } - private void createAndSetRunnerContext(ActionTask 
actionTask, Object key) { - RunnerContextImpl runnerContext; - if (actionTask.action.getExec() instanceof JavaFunction) { - runnerContext = createOrGetRunnerContext(true); - } else if (actionTask.action.getExec() instanceof PythonFunction) { - runnerContext = createOrGetRunnerContext(false); - } else { - throw new IllegalStateException( - "Unsupported action type: " + actionTask.action.getExec().getClass()); - } - - RunnerContextImpl.MemoryContext memoryContext; - if (actionTaskMemoryContexts.containsKey(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. - memoryContext = actionTaskMemoryContexts.get(actionTask); - } else { - memoryContext = - new RunnerContextImpl.MemoryContext( - new CachedMemoryStore(sensoryMemState), - new CachedMemoryStore(shortTermMemState)); - } - - runnerContext.switchActionContext( - actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); - - if (runnerContext instanceof JavaRunnerContextImpl) { - ContinuationContext continuationContext; - if (durableExecManager.hasContinuationContext(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. - continuationContext = durableExecManager.getContinuationContext(actionTask); - } else { - continuationContext = new ContinuationContext(); - } - ((JavaRunnerContextImpl) runnerContext).setContinuationContext(continuationContext); - } - if (runnerContext instanceof PythonRunnerContextImpl) { - // Get the awaitable ref from the transient map. After checkpoint restore, this will be - // null, signaling that the awaitable was lost and needs re-execution. 
- String awaitableRef = durableExecManager.getPythonAwaitableRef(actionTask); - ((PythonRunnerContextImpl) runnerContext).setPythonAwaitableRef(awaitableRef); - } - actionTask.setRunnerContext(runnerContext); - } - private boolean currentKeyHasMoreActionTask() throws Exception { return listStateNotEmpty(actionTasksKState); } @@ -876,7 +637,7 @@ private void tryResumeProcessActionTasks() throws Exception { if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } - keySegmentQueue.addKeyToLastSegment(key); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(key); mailboxExecutor.submit( () -> tryProcessActionTaskForKey(key), "process action task"); } @@ -891,7 +652,10 @@ private void tryResumeProcessActionTasks() throws Exception { (key, state) -> state.get() .forEach( - event -> keySegmentQueue.addKeyToLastSegment(key))); + event -> + eventRouter + .getKeySegmentQueue() + .addKeyToLastSegment(key))); } private void initOrIncSequenceNumber() throws Exception { @@ -904,64 +668,16 @@ private void initOrIncSequenceNumber() throws Exception { } } - private void processEligibleWatermarks() throws Exception { - Watermark mark = keySegmentQueue.popOldestWatermark(); - while (mark != null) { - super.processWatermark(mark); - mark = keySegmentQueue.popOldestWatermark(); - } - } - - private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { - if (isJava) { - if (runnerContext == null) { - if (continuationActionExecutor == null) { - continuationActionExecutor = - new ContinuationActionExecutor( - agentPlan - .getConfig() - .get(AgentExecutionOptions.NUM_ASYNC_THREADS)); - } - runnerContext = - new JavaRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - this.jobIdentifier, - continuationActionExecutor); - } - return runnerContext; - } else { - if (pythonRunnerContext == null) { - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - 
this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - jobIdentifier); - } - return pythonRunnerContext; - } - } - - private EventLogger createEventLogger(AgentPlan agentPlan) { - EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); - String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); - if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { - loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); - } - loggerConfigBuilder.property( - FileEventLogger.PRETTY_PRINT_PROPERTY_KEY, agentPlan.getConfig().get(PRETTY_PRINT)); - return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); - } - @VisibleForTesting DurableExecutionManager getDurableExecutionManager() { return durableExecManager; } + @VisibleForTesting + EventRouter getEventRouter() { + return eventRouter; + } + private KeyGroupRange getCurrentSubtaskKeyGroupRange(int maxParallelism) { int parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java new file mode 100644 index 000000000..0d990a3a8 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.runtime.ResourceCache; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.api.common.state.MapState; + +import java.util.HashMap; +import java.util.Map; + +class ActionTaskContextManager implements AutoCloseable { + + private RunnerContextImpl runnerContext; + + private final Map actionTaskMemoryContexts; + + private ContinuationActionExecutor continuationActionExecutor; + + ActionTaskContextManager(int numAsyncThreads) { + this.actionTaskMemoryContexts = new HashMap<>(); + this.continuationActionExecutor = new ContinuationActionExecutor(numAsyncThreads); + } + + RunnerContextImpl createOrGetRunnerContext( + boolean isJava, + AgentPlan agentPlan, + ResourceCache 
resourceCache, + FlinkAgentsMetricGroupImpl metricGroup, + String jobIdentifier, + Runnable mailboxThreadChecker, + PythonRunnerContextImpl pythonRunnerContext) { + if (isJava) { + if (runnerContext == null) { + if (continuationActionExecutor == null) { + throw new IllegalStateException( + "ContinuationActionExecutor has not been initialized."); + } + runnerContext = + new JavaRunnerContextImpl( + metricGroup, + mailboxThreadChecker, + agentPlan, + resourceCache, + jobIdentifier, + continuationActionExecutor); + } + return runnerContext; + } else { + if (pythonRunnerContext == null) { + throw new IllegalStateException( + "PythonRunnerContextImpl has not been initialized."); + } + return pythonRunnerContext; + } + } + + void createAndSetRunnerContext( + ActionTask actionTask, + Object key, + AgentPlan agentPlan, + ResourceCache resourceCache, + FlinkAgentsMetricGroupImpl metricGroup, + String jobIdentifier, + Runnable mailboxThreadChecker, + MapState sensoryMemState, + MapState shortTermMemState, + PythonRunnerContextImpl pythonRunnerContext, + DurableExecutionManager durableExecManager) { + RunnerContextImpl context; + if (actionTask.action.getExec() instanceof JavaFunction) { + context = + createOrGetRunnerContext( + true, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + mailboxThreadChecker, + pythonRunnerContext); + } else if (actionTask.action.getExec() instanceof PythonFunction) { + context = + createOrGetRunnerContext( + false, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + mailboxThreadChecker, + pythonRunnerContext); + } else { + throw new IllegalStateException( + "Unsupported action type: " + actionTask.action.getExec().getClass()); + } + + RunnerContextImpl.MemoryContext memoryContext; + if (actionTaskMemoryContexts.containsKey(actionTask)) { + memoryContext = actionTaskMemoryContexts.get(actionTask); + } else { + memoryContext = + new RunnerContextImpl.MemoryContext( + new CachedMemoryStore(sensoryMemState), + new 
CachedMemoryStore(shortTermMemState)); + } + + context.switchActionContext( + actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); + + if (context instanceof JavaRunnerContextImpl) { + ContinuationContext continuationContext; + if (durableExecManager.hasContinuationContext(actionTask)) { + // action task for async execution action, should retrieve intermediate results + // from map. + continuationContext = durableExecManager.getContinuationContext(actionTask); + } else { + continuationContext = new ContinuationContext(); + } + ((JavaRunnerContextImpl) context).setContinuationContext(continuationContext); + } + if (context instanceof PythonRunnerContextImpl) { + // Get the awaitable ref from the transient map. After checkpoint restore, this will + // be null, signaling that the awaitable was lost and needs re-execution. + String awaitableRef = durableExecManager.getPythonAwaitableRef(actionTask); + ((PythonRunnerContextImpl) context).setPythonAwaitableRef(awaitableRef); + } + actionTask.setRunnerContext(context); + } + + Resource getResource(String name, ResourceType type, ResourceCache resourceCache) { + try { + return resourceCache.getResource(name, type); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + RunnerContextImpl.MemoryContext getMemoryContext(ActionTask actionTask) { + return actionTaskMemoryContexts.get(actionTask); + } + + void putMemoryContext(ActionTask actionTask, RunnerContextImpl.MemoryContext memoryContext) { + actionTaskMemoryContexts.put(actionTask, memoryContext); + } + + RunnerContextImpl.MemoryContext removeMemoryContext(ActionTask actionTask) { + return actionTaskMemoryContexts.remove(actionTask); + } + + @Override + public void close() throws Exception { + if (runnerContext != null) { + try { + runnerContext.close(); + } finally { + runnerContext = null; + } + } + if (continuationActionExecutor != null) { + continuationActionExecutor.close(); + } + } +} diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java new file mode 100644 index 000000000..e3cbd7f39 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.EventContext; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.listener.EventListener; +import org.apache.flink.agents.api.logger.EventLogger; +import org.apache.flink.agents.api.logger.EventLoggerConfig; +import org.apache.flink.agents.api.logger.EventLoggerFactory; +import org.apache.flink.agents.api.logger.EventLoggerOpenParams; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.eventlog.FileEventLogger; +import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; +import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; +import org.apache.flink.agents.runtime.python.event.PythonEvent; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.utils.EventUtil; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.types.Row; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; +import static org.apache.flink.util.Preconditions.checkState; + +class EventRouter implements AutoCloseable { + + private final boolean inputIsJava; + private final EventLogger eventLogger; + private final List eventListeners; + private StreamRecord reusedStreamRecord; + private SegmentedQueue keySegmentQueue; + private BuiltInMetrics builtInMetrics; + + EventRouter(AgentPlan agentPlan, boolean inputIsJava) 
{ + this.inputIsJava = inputIsJava; + this.eventLogger = createEventLogger(agentPlan); + this.eventListeners = new ArrayList<>(); + } + + void open(BuiltInMetrics builtInMetrics) { + this.reusedStreamRecord = new StreamRecord<>(null); + this.keySegmentQueue = new SegmentedQueue(); + this.builtInMetrics = builtInMetrics; + } + + void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { + if (eventLogger == null) { + return; + } + eventLogger.open(new EventLoggerOpenParams(runtimeContext)); + } + + @SuppressWarnings("unchecked") + Event wrapToInputEvent(IN input, PythonActionExecutor pythonActionExecutor) { + if (inputIsJava) { + return new InputEvent(input); + } else { + // the input data must originate from Python and be of type Row with two fields — the + // first representing the key, and the second representing the actual data payload. + checkState(input instanceof Row && ((Row) input).getArity() == 2); + return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); + } + } + + @SuppressWarnings("unchecked") + OUT getOutputFromOutputEvent(Event event, PythonActionExecutor pythonActionExecutor) { + checkState(EventUtil.isOutputEvent(event)); + if (event instanceof OutputEvent) { + return (OUT) ((OutputEvent) event).getOutput(); + } else if (event instanceof PythonEvent) { + Object outputFromOutputEvent = + pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); + return (OUT) outputFromOutputEvent; + } else { + throw new IllegalStateException( + "Unsupported event type: " + event.getClass().getName()); + } + } + + List getActionsTriggeredBy(Event event, AgentPlan agentPlan) { + if (event instanceof PythonEvent) { + return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); + } else { + return agentPlan.getActionsTriggeredBy(event.getClass().getName()); + } + } + + void notifyEventProcessed(Event event) throws Exception { + EventContext eventContext = new EventContext(event); + if 
(eventLogger != null) { + // If event logging is enabled, we log the event along with its context. + eventLogger.append(eventContext, event); + // For now, we flush the event logger after each event to ensure immediate logging. + // This is a temporary solution to ensure that events are logged immediately. + // TODO: In the future, we may want to implement a more efficient batching mechanism. + eventLogger.flush(); + } + if (eventListeners != null) { + // Notify all registered event listeners about the event. + for (EventListener listener : eventListeners) { + listener.onEventProcessed(eventContext, event); + } + } + builtInMetrics.markEventProcessed(); + } + + void processEligibleWatermarks(WatermarkEmitter watermarkEmitter) throws Exception { + Watermark mark = keySegmentQueue.popOldestWatermark(); + while (mark != null) { + watermarkEmitter.emit(mark); + mark = keySegmentQueue.popOldestWatermark(); + } + } + + SegmentedQueue getKeySegmentQueue() { + return keySegmentQueue; + } + + StreamRecord getReusedStreamRecord() { + return reusedStreamRecord; + } + + @VisibleForTesting + EventLogger getEventLogger() { + return eventLogger; + } + + private EventLogger createEventLogger(AgentPlan agentPlan) { + EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); + String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); + if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { + loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); + } + loggerConfigBuilder.property( + FileEventLogger.PRETTY_PRINT_PROPERTY_KEY, agentPlan.getConfig().get(PRETTY_PRINT)); + return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); + } + + @Override + public void close() throws Exception { + if (eventLogger != null) { + eventLogger.close(); + } + } + + @FunctionalInterface + interface WatermarkEmitter { + void emit(Watermark mark) throws Exception; + } +} diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java new file mode 100644 index 000000000..d23ecfb8b --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; +import org.apache.flink.agents.runtime.PythonMCPResourceDiscovery; +import org.apache.flink.agents.runtime.ResourceCache; +import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; +import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.apache.flink.python.env.PythonDependencyInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pemja.core.PythonInterpreter; + +import java.util.HashMap; + +class PythonBridgeManager implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(PythonBridgeManager.class); + + private PythonEnvironmentManager pythonEnvironmentManager; + private PythonInterpreter pythonInterpreter; + private PythonActionExecutor pythonActionExecutor; + private PythonRunnerContextImpl pythonRunnerContext; + private PythonResourceAdapterImpl pythonResourceAdapter; + private JavaResourceAdapter javaResourceAdapter; + private boolean initialized; + + PythonBridgeManager() { + this.initialized = false; + } + + void open( + AgentPlan agentPlan, + ResourceCache resourceCache, + ExecutionConfig executionConfig, + org.apache.flink.api.common.cache.DistributedCache distributedCache, + String[] tmpDirs, + JobID jobId, 
+ FlinkAgentsMetricGroupImpl metricGroup, + Runnable mailboxThreadChecker, + String jobIdentifier) + throws Exception { + boolean containPythonAction = + agentPlan.getActions().values().stream() + .anyMatch(action -> action.getExec() instanceof PythonFunction); + + boolean containPythonResource = + agentPlan.getResourceProviders().values().stream() + .anyMatch( + resourceProviderMap -> + resourceProviderMap.values().stream() + .anyMatch( + resourceProvider -> + resourceProvider + instanceof + PythonResourceProvider)); + + if (containPythonAction || containPythonResource) { + LOG.debug("Begin initialize PythonEnvironmentManager."); + PythonDependencyInfo dependencyInfo = + PythonDependencyInfo.create( + executionConfig.toConfiguration(), distributedCache); + pythonEnvironmentManager = + new PythonEnvironmentManager( + dependencyInfo, tmpDirs, new HashMap<>(System.getenv()), jobId); + pythonEnvironmentManager.open(); + EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); + pythonInterpreter = env.getInterpreter(); + pythonRunnerContext = + new PythonRunnerContextImpl( + metricGroup, + mailboxThreadChecker, + agentPlan, + resourceCache, + jobIdentifier); + + javaResourceAdapter = + new JavaResourceAdapter( + (name, type) -> { + try { + return resourceCache.getResource(name, type); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + pythonInterpreter); + if (containPythonResource) { + initPythonResourceAdapter(agentPlan, resourceCache); + } + if (containPythonAction) { + initPythonActionExecutor(agentPlan, jobIdentifier); + } + initialized = true; + } + } + + private void initPythonActionExecutor(AgentPlan agentPlan, String jobIdentifier) + throws Exception { + pythonActionExecutor = + new PythonActionExecutor( + pythonInterpreter, + agentPlan, + javaResourceAdapter, + pythonRunnerContext, + jobIdentifier); + pythonActionExecutor.open(); + } + + private void initPythonResourceAdapter(AgentPlan agentPlan, ResourceCache 
resourceCache) + throws Exception { + pythonResourceAdapter = + new PythonResourceAdapterImpl( + (String anotherName, ResourceType anotherType) -> { + try { + return resourceCache.getResource(anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + pythonInterpreter, + javaResourceAdapter); + pythonResourceAdapter.open(); + PythonMCPResourceDiscovery.discoverPythonMCPResources( + agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); + } + + PythonActionExecutor getPythonActionExecutor() { + return pythonActionExecutor; + } + + PythonRunnerContextImpl getPythonRunnerContext() { + return pythonRunnerContext; + } + + boolean isInitialized() { + return initialized; + } + + @Override + public void close() throws Exception { + if (pythonActionExecutor != null) { + pythonActionExecutor.close(); + } + if (pythonInterpreter != null) { + pythonInterpreter.close(); + } + if (pythonEnvironmentManager != null) { + pythonEnvironmentManager.close(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index b5e9fd641..f5608b099 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -310,9 +310,7 @@ void testEventLogBaseDirFromAgentConfig() throws Exception { testHarness.open(); ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - Field eventLoggerField = ActionExecutionOperator.class.getDeclaredField("eventLogger"); - eventLoggerField.setAccessible(true); - Object eventLogger = eventLoggerField.get(operator); + Object eventLogger = operator.getEventRouter().getEventLogger(); assertThat(eventLogger).isInstanceOf(FileEventLogger.class); Field configField = 
FileEventLogger.class.getDeclaredField("config"); From f86b7cd908d709b44d9cc9d4cbf46c4cc9a0e6c7 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sun, 5 Apr 2026 23:57:52 -0700 Subject: [PATCH 04/14] [runtime] Complete state extraction into OperatorStateManager Move all 7 state fields, 2 constants, and state methods from ActionExecutionOperator into OperatorStateManager. All state access now goes through stateManager delegation calls. Add transferContexts() to ActionTaskContextManager to encapsulate the context transfer logic for unfinished async action tasks. Operator shrinks from 700 to 562 lines. Part of #545. --- .../operator/ActionExecutionOperator.java | 230 ++++-------------- .../operator/ActionTaskContextManager.java | 26 ++ .../operator/OperatorStateManager.java | 7 +- 3 files changed, 74 insertions(+), 189 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index da09c6eee..34c408831 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -28,29 +28,15 @@ import org.apache.flink.agents.runtime.ResourceCache; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; -import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; -import org.apache.flink.agents.runtime.context.RunnerContextImpl; -import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; -import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; import org.apache.flink.agents.runtime.python.operator.PythonActionTask; import 
org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; @@ -68,12 +54,10 @@ import org.slf4j.LoggerFactory; import java.lang.reflect.Field; -import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; -import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; /** @@ -93,19 +77,12 @@ public class ActionExecutionOperator extends AbstractStreamOperator sensoryMemState; - - private transient MapState shortTermMemState; - private transient PythonBridgeManager pythonBridge; private transient FlinkAgentsMetricGroupImpl metricGroup; @@ -123,31 +100,11 @@ public class ActionExecutionOperator extends AbstractStreamOperator actionTasksKState; - - // To avoid processing different InputEvents with the same key, we use a state 
to store pending - // InputEvents that are waiting to be processed. - private transient ListState pendingInputEventsKState; - - // An operator state is used to track the currently processing keys. This is useful when - // receiving an EndOfInput signal, as we need to wait until all related events are fully - // processed. - private transient ListState currentProcessingKeysOpState; - private final transient EventRouter eventRouter; - private transient ValueState sequenceNumberKState; - private final transient DurableExecutionManager durableExecManager; - // Each job can only have one identifier and this identifier must be consistent across restarts. - // We cannot use job id as the identifier here because user may change job id by - // creating a savepoint, stop the job and then resume from savepoint. - // We use this identifier to control the visibility for long-term memory. - // Inspired by Apache Paimon. - private transient String jobIdentifier; + private transient OperatorStateManager stateManager; public ActionExecutionOperator( AgentPlan agentPlan, @@ -175,21 +132,10 @@ public void setup( @Override public void open() throws Exception { super.open(); - // init sensoryMemState - MapStateDescriptor sensoryMemStateDescriptor = - new MapStateDescriptor<>( - "sensoryMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); - - // init shortTermMemState - MapStateDescriptor shortTermMemStateDescriptor = - new MapStateDescriptor<>( - "shortTermMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - shortTermMemState = getRuntimeContext().getMapState(shortTermMemStateDescriptor); + + stateManager = new OperatorStateManager(); + stateManager.initializeKeyedStates(getRuntimeContext()); + stateManager.initializeOperatorStates(getOperatorStateBackend()); resourceCache = new 
ResourceCache(agentPlan.getResourceProviders()); @@ -200,34 +146,6 @@ public void open() throws Exception { durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); durableExecManager.initRecoveryMarkerState(getOperatorStateBackend()); - // init sequence number state for per key message ordering - sequenceNumberKState = - getRuntimeContext() - .getState( - new ValueStateDescriptor<>( - MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); - - // init agent processing related state - actionTasksKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - "actionTasks", TypeInformation.of(ActionTask.class))); - pendingInputEventsKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, - TypeInformation.of(Event.class))); - // We use UnionList here to ensure that the task can access all keys after parallelism - // modifications. - // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do - // not belong to the key range of current task. - currentProcessingKeysOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - "currentProcessingKeys", TypeInformation.of(Object.class))); // init PythonActionExecutor and PythonResourceAdapter pythonBridge = new PythonBridgeManager(); @@ -240,7 +158,7 @@ public void open() throws Exception { getRuntimeContext().getJobInfo().getJobId(), metricGroup, this::checkMailboxThread, - jobIdentifier); + stateManager.getJobIdentifier()); // init context manager for runner context creation and memory contexts contextManager = @@ -279,11 +197,11 @@ public void processElement(StreamRecord record) throws Exception { eventRouter.getKeySegmentQueue().addKeyToLastSegment(getCurrentKey()); - if (currentKeyHasMoreActionTask()) { + if (stateManager.hasMoreActionTasks()) { // If there are already actions being processed for the current key, the newly incoming // event should be queued and processed later. 
Therefore, we add it to // pendingInputEventsState. - pendingInputEventsKState.add(inputEvent); + stateManager.addPendingInputEvent(inputEvent); } else { // Otherwise, the new event is processed immediately. processEvent(getCurrentKey(), inputEvent); @@ -315,15 +233,15 @@ private void processEvent(Object key, Event event) throws Exception { } else { if (isInputEvent) { // If the event is an InputEvent, we mark that the key is currently being processed. - currentProcessingKeysOpState.add(key); - initOrIncSequenceNumber(); + stateManager.addProcessingKey(key); + stateManager.initOrIncSequenceNumber(); } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. List triggerActions = eventRouter.getActionsTriggeredBy(event, agentPlan); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { - actionTasksKState.add(createActionTask(key, triggerAction, event)); + stateManager.addActionTask(createActionTask(key, triggerAction, event)); } } } @@ -351,9 +269,9 @@ private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. 
setCurrentKey(key); - ActionTask actionTask = pollFromListState(actionTasksKState); + ActionTask actionTask = stateManager.pollNextActionTask(); if (actionTask == null) { - int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); checkState( removedCount == 1, "Current processing key count for key " @@ -374,14 +292,14 @@ private void processActionTaskForKey(Object key) throws Exception { agentPlan, resourceCache, metricGroup, - jobIdentifier, + stateManager.getJobIdentifier(), this::checkMailboxThread, - sensoryMemState, - shortTermMemState, + stateManager.getSensoryMemState(), + stateManager.getShortTermMemState(), pythonBridge.getPythonRunnerContext(), durableExecManager); - long sequenceNumber = sequenceNumberKState.value(); + long sequenceNumber = stateManager.getSequenceNumber(); boolean isFinished; List outputEvents; Optional generatedActionTaskOpt = Optional.empty(); @@ -456,7 +374,7 @@ private void processActionTaskForKey(Object key) throws Exception { boolean currentInputEventFinished = false; if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); - currentInputEventFinished = !currentKeyHasMoreActionTask(); + currentInputEventFinished = !stateManager.hasMoreActionTasks(); // Persist memory to the Flink state when the action task is finished. actionTask.getRunnerContext().persistMemory(); @@ -471,29 +389,9 @@ private void processActionTaskForKey(Object key) throws Exception { // If the action task is not finished, we keep the contexts in memory for the // next generated ActionTask to be invoked. 
- contextManager.putMemoryContext( - generatedActionTask, actionTask.getRunnerContext().getMemoryContext()); - RunnerContextImpl.DurableExecutionContext durableContext = - actionTask.getRunnerContext().getDurableExecutionContext(); - if (durableContext != null) { - durableExecManager.putDurableContext(generatedActionTask, durableContext); - } - if (actionTask.getRunnerContext() instanceof JavaRunnerContextImpl) { - durableExecManager.putContinuationContext( - generatedActionTask, - ((JavaRunnerContextImpl) actionTask.getRunnerContext()) - .getContinuationContext()); - } - if (actionTask.getRunnerContext() instanceof PythonRunnerContextImpl) { - String awaitableRef = - ((PythonRunnerContextImpl) actionTask.getRunnerContext()) - .getPythonAwaitableRef(); - if (awaitableRef != null) { - durableExecManager.putPythonAwaitableRef(generatedActionTask, awaitableRef); - } - } + contextManager.transferContexts(actionTask, generatedActionTask, durableExecManager); - actionTasksKState.add(generatedActionTask); + stateManager.addActionTask(generatedActionTask); } // 3. Process the next InputEvent or next action task @@ -503,7 +401,7 @@ private void processActionTaskForKey(Object key) throws Exception { // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. 
- int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); durableExecManager.maybePruneState(key, sequenceNumber); checkState( removedCount == 1, @@ -515,11 +413,11 @@ private void processActionTaskForKey(Object key) throws Exception { eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); eventRouter.processEligibleWatermarks(super::processWatermark); - Event pendingInputEvent = pollFromListState(pendingInputEventsKState); + Event pendingInputEvent = stateManager.pollNextPendingInputEvent(); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); } - } else if (currentKeyHasMoreActionTask()) { + } else if (stateManager.hasMoreActionTasks()) { // If the current key has additional action tasks remaining, we should submit a new mail // to continue processing them. mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); @@ -533,7 +431,7 @@ public void endInput() throws Exception { @VisibleForTesting public void waitInFlightEventsFinished() throws Exception { - while (listStateNotEmpty(currentProcessingKeysOpState)) { + while (stateManager.hasProcessingKeys()) { mailboxExecutor.yield(); } } @@ -567,28 +465,16 @@ public void initializeState(StateInitializationContext context) throws Exception durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); durableExecManager.handleRecovery(getOperatorStateBackend()); - // Get job identifier from user configuration. - // If not configured, get from state. 
- jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); - if (jobIdentifier == null) { - String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); - jobIdentifier = - StateUtils.getSingleValueFromState( - context, "identifier_state", String.class, initialJobIdentifier); - } + stateManager = new OperatorStateManager(); + stateManager.initJobIdentifier(context, agentPlan, getRuntimeContext()); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { durableExecManager.snapshotRecoveryMarker(); - HashMap keyToSeqNum = new HashMap<>(); - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), - (key, state) -> keyToSeqNum.put(key, state.value())); + Map keyToSeqNum = + stateManager.snapshotSequenceNumbers(getKeyedStateBackend()); durableExecManager.recordCheckpointSequenceNumbers(context.getCheckpointId(), keyToSeqNum); super.snapshotState(context); @@ -623,18 +509,16 @@ private ActionTask createActionTask(Object key, Action action, Event event) { } } - private boolean currentKeyHasMoreActionTask() throws Exception { - return listStateNotEmpty(actionTasksKState); - } - private void tryResumeProcessActionTasks() throws Exception { - Iterable keys = currentProcessingKeysOpState.get(); + Iterable keys = stateManager.getProcessingKeys(); if (keys != null) { int maxParallelism = getRuntimeContext().getTaskInfo().getMaxNumberOfParallelSubtasks(); KeyGroupRange currentSubtaskKeyGroupRange = - getCurrentSubtaskKeyGroupRange(maxParallelism); + stateManager.getCurrentSubtaskKeyGroupRange( + maxParallelism, getRuntimeContext()); for (Object key : keys) { - if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { + if (!stateManager.isKeyOwnedByCurrentSubtask( + key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } 
eventRouter.getKeySegmentQueue().addKeyToLastSegment(key); @@ -643,29 +527,15 @@ private void tryResumeProcessActionTasks() throws Exception { } } - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), - (key, state) -> - state.get() - .forEach( - event -> - eventRouter - .getKeySegmentQueue() - .addKeyToLastSegment(key))); - } - - private void initOrIncSequenceNumber() throws Exception { - // Initialize the sequence number state if it does not exist. - Long sequenceNumber = sequenceNumberKState.value(); - if (sequenceNumber == null) { - sequenceNumberKState.update(0L); - } else { - sequenceNumberKState.update(sequenceNumber + 1); - } + stateManager.forEachPendingInputEventKey( + getKeyedStateBackend(), + (key, state) -> + state.get() + .forEach( + event -> + eventRouter + .getKeySegmentQueue() + .addKeyToLastSegment(key))); } @VisibleForTesting @@ -678,17 +548,9 @@ EventRouter getEventRouter() { return eventRouter; } - private KeyGroupRange getCurrentSubtaskKeyGroupRange(int maxParallelism) { - int parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); - int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); - return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( - maxParallelism, parallelism, subtaskIndex); - } - - private boolean isKeyOwnedByCurrentSubtask( - Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) { - int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); - return currentSubtaskKeyGroupRange.contains(keyGroup); + @VisibleForTesting + OperatorStateManager getOperatorStateManager() { + return stateManager; } /** Failed to execute Action task. 
*/ diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java index 0d990a3a8..ca6074d61 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -173,6 +173,32 @@ RunnerContextImpl.MemoryContext removeMemoryContext(ActionTask actionTask) { return actionTaskMemoryContexts.remove(actionTask); } + /** + * Transfers memory, durable execution, continuation, and Python awaitable contexts from the + * completed action task to the generated (next) action task. + */ + void transferContexts( + ActionTask fromTask, ActionTask toTask, DurableExecutionManager durableExecManager) { + putMemoryContext(toTask, fromTask.getRunnerContext().getMemoryContext()); + RunnerContextImpl.DurableExecutionContext durableContext = + fromTask.getRunnerContext().getDurableExecutionContext(); + if (durableContext != null) { + durableExecManager.putDurableContext(toTask, durableContext); + } + if (fromTask.getRunnerContext() instanceof JavaRunnerContextImpl) { + durableExecManager.putContinuationContext( + toTask, + ((JavaRunnerContextImpl) fromTask.getRunnerContext()).getContinuationContext()); + } + if (fromTask.getRunnerContext() instanceof PythonRunnerContextImpl) { + String awaitableRef = + ((PythonRunnerContextImpl) fromTask.getRunnerContext()).getPythonAwaitableRef(); + if (awaitableRef != null) { + durableExecManager.putPythonAwaitableRef(toTask, awaitableRef); + } + } + } + @Override public void close() throws Exception { if (runnerContext != null) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index 0a20670f3..d99444006 100644 --- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -197,10 +197,7 @@ boolean isKeyOwnedByCurrentSubtask( } @SuppressWarnings("unchecked") - void snapshotSequenceNumbers( - KeyedStateBackend<?> keyedStateBackend, - Map<Long, Map<Object, Long>> checkpointIdToSeqNums, - long checkpointId) + Map<Object, Long> snapshotSequenceNumbers(KeyedStateBackend<?> keyedStateBackend) throws Exception { HashMap<Object, Long> keyToSeqNum = new HashMap<>(); ((KeyedStateBackend<Object>) keyedStateBackend) .applyToAllKeys( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), (key, state) -> keyToSeqNum.put(key, state.value())); - checkpointIdToSeqNums.put(checkpointId, keyToSeqNum); + return keyToSeqNum; } @SuppressWarnings("unchecked") From 78a1219b20ec164a9bdc07c82d65feb7d18fc921 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Mon, 6 Apr 2026 08:16:28 -0700 Subject: [PATCH 05/14] [runtime] Fix stateManager double-creation causing null jobIdentifier OperatorStateManager was created in both initializeState() and open(), causing the instance with jobIdentifier (set in initializeState) to be overwritten by a fresh instance in open(). This passed null jobIdentifier to PythonBridgeManager, causing VectorStoreLongTermMemory validation failure and visit_count memory scoping issues in e2e tests. Fix: create stateManager only in initializeState() (which runs first in Flink lifecycle), then initialize state descriptors in open(). Part of #545.
--- .../flink/agents/runtime/operator/ActionExecutionOperator.java | 1 - 1 file changed, 1 deletion(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 34c408831..010f2d2a4 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -133,7 +133,6 @@ public void setup( public void open() throws Exception { super.open(); - stateManager = new OperatorStateManager(); stateManager.initializeKeyedStates(getRuntimeContext()); stateManager.initializeOperatorStates(getOperatorStateBackend()); From 300a0faec13262c4386a4652b1fab0e657283d7f Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Fri, 17 Apr 2026 01:14:27 -0700 Subject: [PATCH 06/14] [runtime] Remove unused getResource helper from ActionTaskContextManager Addresses review feedback on PR #546: the getResource method was never invoked. Callers go through RunnerContextImpl.getResource or directly through ResourceCache. Also removes the now-unused Resource/ResourceType imports. 
--- .../runtime/operator/ActionTaskContextManager.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java index ca6074d61..2684e1ee0 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -17,8 +17,6 @@ */ package org.apache.flink.agents.runtime.operator; -import org.apache.flink.agents.api.resource.Resource; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.PythonFunction; @@ -153,14 +151,6 @@ void createAndSetRunnerContext( actionTask.setRunnerContext(context); } - Resource getResource(String name, ResourceType type, ResourceCache resourceCache) { - try { - return resourceCache.getResource(name, type); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - RunnerContextImpl.MemoryContext getMemoryContext(ActionTask actionTask) { return actionTaskMemoryContexts.get(actionTask); } From ab0c58143b2705cdd873d5d9b7a5d5e7f01f0aa1 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 28 Apr 2026 20:13:45 -0700 Subject: [PATCH 07/14] [runtime] Move continuation/awaitable maps from DEM to ATCM Address xintongsong's review feedback (#546 comments #6, #7): the continuationContexts and pythonAwaitableRefs maps are runner-context concerns, not durable-execution concerns, and their placement in DurableExecutionManager forced ActionTaskContextManager to hold a manager-to-manager reference forbidden by 545-DESIGN.md. Move both maps and their accessors into ActionTaskContextManager. actionTaskDurableContexts stays in DEM since DEM consumes it via setupDurableExecutionContext. 
Drop the DEM parameter from createAndSetRunnerContext (no longer needed); transferContexts keeps the DEM parameter for putDurableContext only. --- .../operator/ActionExecutionOperator.java | 7 ++- .../operator/ActionTaskContextManager.java | 45 ++++++++++++++++--- .../operator/DurableExecutionManager.java | 35 +-------------- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 010f2d2a4..b4ab898e1 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -295,8 +295,7 @@ private void processActionTaskForKey(Object key) throws Exception { this::checkMailboxThread, stateManager.getSensoryMemState(), stateManager.getShortTermMemState(), - pythonBridge.getPythonRunnerContext(), - durableExecManager); + pythonBridge.getPythonRunnerContext()); long sequenceNumber = stateManager.getSequenceNumber(); boolean isFinished; @@ -352,8 +351,8 @@ private void processActionTaskForKey(Object key) throws Exception { // finished. 
contextManager.removeMemoryContext(actionTask); durableExecManager.removeDurableContext(actionTask); - durableExecManager.removeContinuationContext(actionTask); - durableExecManager.removePythonAwaitableRef(actionTask); + contextManager.removeContinuationContext(actionTask); + contextManager.removePythonAwaitableRef(actionTask); durableExecManager.maybePersistTaskResult( key, sequenceNumber, diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java index 2684e1ee0..0daf51962 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -39,11 +39,15 @@ class ActionTaskContextManager implements AutoCloseable { private RunnerContextImpl runnerContext; private final Map actionTaskMemoryContexts; + private final Map continuationContexts; + private final Map pythonAwaitableRefs; private ContinuationActionExecutor continuationActionExecutor; ActionTaskContextManager(int numAsyncThreads) { this.actionTaskMemoryContexts = new HashMap<>(); + this.continuationContexts = new HashMap<>(); + this.pythonAwaitableRefs = new HashMap<>(); this.continuationActionExecutor = new ContinuationActionExecutor(numAsyncThreads); } @@ -90,8 +94,7 @@ void createAndSetRunnerContext( Runnable mailboxThreadChecker, MapState sensoryMemState, MapState shortTermMemState, - PythonRunnerContextImpl pythonRunnerContext, - DurableExecutionManager durableExecManager) { + PythonRunnerContextImpl pythonRunnerContext) { RunnerContextImpl context; if (actionTask.action.getExec() instanceof JavaFunction) { context = @@ -133,10 +136,10 @@ void createAndSetRunnerContext( if (context instanceof JavaRunnerContextImpl) { ContinuationContext continuationContext; - if (durableExecManager.hasContinuationContext(actionTask)) { + if 
(this.hasContinuationContext(actionTask)) { // action task for async execution action, should retrieve intermediate results // from map. - continuationContext = durableExecManager.getContinuationContext(actionTask); + continuationContext = this.getContinuationContext(actionTask); } else { continuationContext = new ContinuationContext(); } @@ -145,7 +148,7 @@ void createAndSetRunnerContext( if (context instanceof PythonRunnerContextImpl) { // Get the awaitable ref from the transient map. After checkpoint restore, this will // be null, signaling that the awaitable was lost and needs re-execution. - String awaitableRef = durableExecManager.getPythonAwaitableRef(actionTask); + String awaitableRef = this.getPythonAwaitableRef(actionTask); ((PythonRunnerContextImpl) context).setPythonAwaitableRef(awaitableRef); } actionTask.setRunnerContext(context); @@ -176,7 +179,7 @@ void transferContexts( durableExecManager.putDurableContext(toTask, durableContext); } if (fromTask.getRunnerContext() instanceof JavaRunnerContextImpl) { - durableExecManager.putContinuationContext( + this.putContinuationContext( toTask, ((JavaRunnerContextImpl) fromTask.getRunnerContext()).getContinuationContext()); } @@ -184,11 +187,39 @@ void transferContexts( String awaitableRef = ((PythonRunnerContextImpl) fromTask.getRunnerContext()).getPythonAwaitableRef(); if (awaitableRef != null) { - durableExecManager.putPythonAwaitableRef(toTask, awaitableRef); + this.putPythonAwaitableRef(toTask, awaitableRef); } } } + ContinuationContext getContinuationContext(ActionTask actionTask) { + return continuationContexts.get(actionTask); + } + + void putContinuationContext(ActionTask actionTask, ContinuationContext context) { + continuationContexts.put(actionTask, context); + } + + void removeContinuationContext(ActionTask actionTask) { + continuationContexts.remove(actionTask); + } + + boolean hasContinuationContext(ActionTask actionTask) { + return continuationContexts.containsKey(actionTask); + } + + String 
getPythonAwaitableRef(ActionTask actionTask) { + return pythonAwaitableRefs.get(actionTask); + } + + void putPythonAwaitableRef(ActionTask actionTask, String ref) { + pythonAwaitableRefs.put(actionTask, ref); + } + + void removePythonAwaitableRef(ActionTask actionTask) { + pythonAwaitableRefs.remove(actionTask); + } + @Override public void close() throws Exception { if (runnerContext != null) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java index 0be24052b..495002577 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -24,7 +24,6 @@ import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; -import org.apache.flink.agents.runtime.async.ContinuationContext; import org.apache.flink.agents.runtime.context.ActionStatePersister; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.annotation.VisibleForTesting; @@ -57,15 +56,11 @@ class DurableExecutionManager implements ActionStatePersister, AutoCloseable { private final Map actionTaskDurableContexts; - private final Map continuationContexts; - private final Map pythonAwaitableRefs; DurableExecutionManager(@Nullable ActionStateStore actionStateStore) { this.actionStateStore = actionStateStore; this.checkpointIdToSeqNums = new HashMap<>(); this.actionTaskDurableContexts = new HashMap<>(); - this.continuationContexts = new HashMap<>(); - this.pythonAwaitableRefs = new HashMap<>(); } void maybeInitActionStateStore(AgentConfiguration config) { @@ -224,7 +219,7 @@ void recordCheckpointSequenceNumbers(long checkpointId, Map seqNum 
checkpointIdToSeqNums.put(checkpointId, seqNums); } - // --- Context map accessors --- + // --- Durable execution context map accessors --- RunnerContextImpl.DurableExecutionContext getDurableContext(ActionTask actionTask) { return actionTaskDurableContexts.get(actionTask); @@ -239,38 +234,10 @@ void removeDurableContext(ActionTask actionTask) { actionTaskDurableContexts.remove(actionTask); } - ContinuationContext getContinuationContext(ActionTask actionTask) { - return continuationContexts.get(actionTask); - } - - void putContinuationContext(ActionTask actionTask, ContinuationContext context) { - continuationContexts.put(actionTask, context); - } - - void removeContinuationContext(ActionTask actionTask) { - continuationContexts.remove(actionTask); - } - - String getPythonAwaitableRef(ActionTask actionTask) { - return pythonAwaitableRefs.get(actionTask); - } - - void putPythonAwaitableRef(ActionTask actionTask, String ref) { - pythonAwaitableRefs.put(actionTask, ref); - } - - void removePythonAwaitableRef(ActionTask actionTask) { - pythonAwaitableRefs.remove(actionTask); - } - boolean hasDurableContext(ActionTask actionTask) { return actionTaskDurableContexts.containsKey(actionTask); } - boolean hasContinuationContext(ActionTask actionTask) { - return continuationContexts.containsKey(actionTask); - } - @VisibleForTesting ActionStateStore getActionStateStore() { return actionStateStore; From 1139b9c708828807e0c3503c96647f6a532e728e Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 28 Apr 2026 20:35:33 -0700 Subject: [PATCH 08/14] [runtime] Remove dead inputIsJava field; move jobIdentifier to operator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address xintongsong's review feedback (#546 comments #3, #5): #5 — `inputIsJava` field on the operator was never read; the constructor parameter is forwarded directly to EventRouter. Remove the field; keep the constructor parameter. 
#3 — `jobIdentifier` is a runtime identity, not a Flink state descriptor. OperatorStateManager exists to own keyed/operator state; identity belongs on the operator. Move the `jobIdentifier` field, its initialization (formerly `OperatorStateManager.initJobIdentifier`), and the accessor into ActionExecutionOperator. Inline the init body into the operator's `initializeState()`. Preserve the savepoint state descriptor name `"identifier_state"` exactly to retain compatibility with existing savepoints. --- .../operator/ActionExecutionOperator.java | 24 ++++++++++++++----- .../operator/OperatorStateManager.java | 24 ------------------- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index b4ab898e1..2d8933a3c 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -58,6 +58,7 @@ import java.util.Map; import java.util.Optional; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.util.Preconditions.checkState; /** @@ -81,8 +82,6 @@ public class ActionExecutionOperator extends AbstractStreamOperator extends AbstractStreamOperator(agentPlan, inputIsJava); @@ -157,7 +157,7 @@ public void open() throws Exception { getRuntimeContext().getJobInfo().getJobId(), metricGroup, this::checkMailboxThread, - stateManager.getJobIdentifier()); + jobIdentifier); // init context manager for runner context creation and memory contexts contextManager = @@ -291,7 +291,7 @@ private void processActionTaskForKey(Object key) throws Exception { agentPlan, resourceCache, metricGroup, - stateManager.getJobIdentifier(), + jobIdentifier, this::checkMailboxThread, 
stateManager.getSensoryMemState(), stateManager.getShortTermMemState(), @@ -464,7 +464,19 @@ public void initializeState(StateInitializationContext context) throws Exception durableExecManager.handleRecovery(getOperatorStateBackend()); stateManager = new OperatorStateManager(); - stateManager.initJobIdentifier(context, agentPlan, getRuntimeContext()); + + // Resolve the agent's stable job identifier: + // - If the user set it via AgentConfigOptions.JOB_IDENTIFIER, use that. + // - Otherwise fall back to the current Flink JobID, cached in operator + // state so the value remains stable across job restarts (Flink + // generates a fresh JobID on each restart). + jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); + if (jobIdentifier == null) { + String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); + jobIdentifier = + StateUtils.getSingleValueFromState( + context, "identifier_state", String.class, initialJobIdentifier); + } } @Override diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index d99444006..15bb3b500 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -19,7 +19,6 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -33,14 +32,12 @@ import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.runtime.state.OperatorStateBackend; -import 
org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; import java.util.HashMap; import java.util.Map; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.agents.runtime.utils.StateUtil.*; class OperatorStateManager { @@ -54,7 +51,6 @@ class OperatorStateManager { private ValueState sequenceNumberKState; private MapState sensoryMemState; private MapState shortTermMemState; - private String jobIdentifier; OperatorStateManager() {} @@ -103,22 +99,6 @@ void initializeOperatorStates(OperatorStateBackend operatorStateBackend) throws "currentProcessingKeys", TypeInformation.of(Object.class))); } - void initJobIdentifier( - StateInitializationContext context, - AgentPlan agentPlan, - org.apache.flink.api.common.functions.RuntimeContext runtimeContext) - throws Exception { - // Get job identifier from user configuration. - // If not configured, get from state. - jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); - if (jobIdentifier == null) { - String initialJobIdentifier = runtimeContext.getJobInfo().getJobId().toString(); - jobIdentifier = - StateUtils.getSingleValueFromState( - context, "identifier_state", String.class, initialJobIdentifier); - } - } - void initOrIncSequenceNumber() throws Exception { // Initialize the sequence number state if it does not exist. 
Long sequenceNumber = sequenceNumberKState.value(); @@ -177,10 +157,6 @@ MapState getShortTermMemState() { return shortTermMemState; } - String getJobIdentifier() { - return jobIdentifier; - } - KeyGroupRange getCurrentSubtaskKeyGroupRange( int maxParallelism, org.apache.flink.api.common.functions.RuntimeContext runtimeContext) { From 85df6847a4544e54e682f62637122f7ed15c6e03 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 28 Apr 2026 20:45:29 -0700 Subject: [PATCH 09/14] [runtime] Annotate manager methods that may return null MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address xintongsong's review feedback (#546 comment #4): "I'd suggest to return an Optional or mark @Nullable for methods that may return null." Apply javax.annotation.Nullable across all 5 managers (consistent with the existing @Nullable ActionStateStore parameter style — Optional would be inconsistent with the codebase). 16 methods total: getter/poll methods backed by Map.get/Map.remove/poll-style state operations, plus fields that are null until a separate lifecycle phase initializes them (e.g., Python bridge components in no-Python jobs, lazily-set state descriptors). 
--- .../agents/runtime/operator/ActionTaskContextManager.java | 6 ++++++ .../agents/runtime/operator/DurableExecutionManager.java | 3 +++ .../apache/flink/agents/runtime/operator/EventRouter.java | 5 +++++ .../flink/agents/runtime/operator/OperatorStateManager.java | 6 ++++++ .../flink/agents/runtime/operator/PythonBridgeManager.java | 4 ++++ 5 files changed, 24 insertions(+) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java index 0daf51962..b304a6807 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -31,6 +31,8 @@ import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; import org.apache.flink.api.common.state.MapState; +import javax.annotation.Nullable; + import java.util.HashMap; import java.util.Map; @@ -154,6 +156,7 @@ void createAndSetRunnerContext( actionTask.setRunnerContext(context); } + @Nullable RunnerContextImpl.MemoryContext getMemoryContext(ActionTask actionTask) { return actionTaskMemoryContexts.get(actionTask); } @@ -162,6 +165,7 @@ void putMemoryContext(ActionTask actionTask, RunnerContextImpl.MemoryContext mem actionTaskMemoryContexts.put(actionTask, memoryContext); } + @Nullable RunnerContextImpl.MemoryContext removeMemoryContext(ActionTask actionTask) { return actionTaskMemoryContexts.remove(actionTask); } @@ -192,6 +196,7 @@ void transferContexts( } } + @Nullable ContinuationContext getContinuationContext(ActionTask actionTask) { return continuationContexts.get(actionTask); } @@ -208,6 +213,7 @@ boolean hasContinuationContext(ActionTask actionTask) { return continuationContexts.containsKey(actionTask); } + @Nullable String getPythonAwaitableRef(ActionTask actionTask) { return pythonAwaitableRefs.get(actionTask); } diff 
--git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java index 495002577..26564b2a2 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -103,6 +103,7 @@ void handleRecovery(OperatorStateBackend operatorStateBackend) throws Exception } } + @Nullable ActionState maybeGetActionState(Object key, long sequenceNum, Action action, Event event) throws Exception { return actionStateStore == null @@ -221,6 +222,7 @@ void recordCheckpointSequenceNumbers(long checkpointId, Map seqNum // --- Durable execution context map accessors --- + @Nullable RunnerContextImpl.DurableExecutionContext getDurableContext(ActionTask actionTask) { return actionTaskDurableContexts.get(actionTask); } @@ -239,6 +241,7 @@ boolean hasDurableContext(ActionTask actionTask) { } @VisibleForTesting + @Nullable ActionStateStore getActionStateStore() { return actionStateStore; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java index e3cbd7f39..c8eb9501f 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -40,6 +40,8 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.types.Row; +import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.List; @@ -137,15 +139,18 @@ void processEligibleWatermarks(WatermarkEmitter watermarkEmitter) throws Excepti } } + @Nullable SegmentedQueue getKeySegmentQueue() { return keySegmentQueue; } + @Nullable StreamRecord getReusedStreamRecord() { return reusedStreamRecord; } 
@VisibleForTesting + @Nullable EventLogger getEventLogger() { return eventLogger; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index 15bb3b500..ffee9f241 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -35,6 +35,8 @@ import org.apache.flink.runtime.state.VoidNamespace; import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import javax.annotation.Nullable; + import java.util.HashMap; import java.util.Map; @@ -117,6 +119,7 @@ boolean hasMoreActionTasks() throws Exception { return listStateNotEmpty(actionTasksKState); } + @Nullable ActionTask pollNextActionTask() throws Exception { return pollFromListState(actionTasksKState); } @@ -129,6 +132,7 @@ void addPendingInputEvent(Event event) throws Exception { pendingInputEventsKState.add(event); } + @Nullable Event pollNextPendingInputEvent() throws Exception { return pollFromListState(pendingInputEventsKState); } @@ -149,10 +153,12 @@ Iterable getProcessingKeys() throws Exception { return currentProcessingKeysOpState.get(); } + @Nullable MapState getSensoryMemState() { return sensoryMemState; } + @Nullable MapState getShortTermMemState() { return shortTermMemState; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java index d23ecfb8b..46578fe02 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -37,6 +37,8 @@ import org.slf4j.LoggerFactory; import pemja.core.PythonInterpreter; +import javax.annotation.Nullable; + import 
java.util.HashMap; class PythonBridgeManager implements AutoCloseable { @@ -150,10 +152,12 @@ private void initPythonResourceAdapter(AgentPlan agentPlan, ResourceCache resour agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); } + @Nullable PythonActionExecutor getPythonActionExecutor() { return pythonActionExecutor; } + @Nullable PythonRunnerContextImpl getPythonRunnerContext() { return pythonRunnerContext; } From 6991b9901c9550d4b84bd23d9804e5e5b3575d41 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 28 Apr 2026 21:14:04 -0700 Subject: [PATCH 10/14] [runtime] Document the contract of each manager class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address xintongsong's review feedback (#546 top-level): "we need more JavaDocs to explicitly describe the contract of each component interfaces... Currently they are a bit vague relying on the names and implementations to understand." Add class-level Javadoc to all 5 managers (OperatorStateManager, DurableExecutionManager, ActionTaskContextManager, EventRouter, PythonBridgeManager) covering responsibility, owned state, lifecycle (when constructed, when init/open methods run, when close runs), and the design constraint that managers do not hold references to one another — cross-cutting data flows via method parameters with the operator as mediator. Add method-level Javadoc to contract-defining methods only — state initialization, sequence-number semantics, poll semantics, snapshot and recovery hooks, durable-mode no-op behavior, Java vs Python branching, watermark draining, and the close ordering inside the Python bridge. 
--- .../operator/ActionTaskContextManager.java | 89 ++++++++++++++- .../operator/DurableExecutionManager.java | 102 +++++++++++++++++- .../agents/runtime/operator/EventRouter.java | 80 ++++++++++++++ .../operator/OperatorStateManager.java | 81 ++++++++++++++ .../runtime/operator/PythonBridgeManager.java | 54 ++++++++++ 5 files changed, 401 insertions(+), 5 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java index b304a6807..5ffc37d17 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -36,6 +36,34 @@ import java.util.HashMap; import java.util.Map; +/** + * Owns the per-{@link ActionTask} runtime context bookkeeping for {@link ActionExecutionOperator}. + * + *

Owned state: + * + *

    + *
  • The shared (Java) {@link RunnerContextImpl} that is reused across action tasks via {@link + * RunnerContextImpl#switchActionContext}. + *
  • Three per-{@link ActionTask} maps that survive across the boundary between a finishing + * action and the action it generates: memory contexts, continuation contexts (for async Java + * actions), and Python awaitable references. + *
  • The {@link ContinuationActionExecutor} thread pool used to run async Java continuations. + *
+ * + *

Lifecycle: instantiated by the operator's {@code open()} with the configured async-thread + * count from the agent plan. Has no separate {@code open()} step — fully constructed in the + * operator's {@code open()}. {@link #close()} closes the shared runner context and the continuation + * executor. + * + *

Note: the Python {@link RunnerContextImpl} is not owned here — it is owned by {@link + * PythonBridgeManager} and passed in as a parameter to {@link #createOrGetRunnerContext} and {@link + * #createAndSetRunnerContext}. The durable-execution context map likewise lives on {@link + * DurableExecutionManager} and is accessed via the manager parameter passed to {@link + * #transferContexts}. + * + *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data + * flows via method parameters. + */ class ActionTaskContextManager implements AutoCloseable { private RunnerContextImpl runnerContext; @@ -53,6 +81,24 @@ class ActionTaskContextManager implements AutoCloseable { this.continuationActionExecutor = new ContinuationActionExecutor(numAsyncThreads); } + /** + * Returns a runner context for an action's exec language. + * + *

For Java actions, lazily creates a single {@link JavaRunnerContextImpl} that is reused for + * every Java action. For Python actions, returns the supplied {@link PythonRunnerContextImpl} + * (owned by {@link PythonBridgeManager}). Throws {@link IllegalStateException} if a Python + * context is requested but none was provided, or if the continuation executor has not been + * initialized. + * + * @param isJava {@code true} if the action is a Java action, {@code false} if Python. + * @param agentPlan the agent plan, used when creating the Java runner context. + * @param resourceCache the resource cache, used when creating the Java runner context. + * @param metricGroup the agent metric group. + * @param jobIdentifier the job identifier. + * @param mailboxThreadChecker hook used by runner contexts to assert mailbox-thread access. + * @param pythonRunnerContext the pre-built Python runner context, or {@code null} for Java. + * @return the runner context appropriate for the action's exec language. + */ RunnerContextImpl createOrGetRunnerContext( boolean isJava, AgentPlan agentPlan, @@ -86,6 +132,36 @@ RunnerContextImpl createOrGetRunnerContext( } } + /** + * Resolves the runner context for the given action task, switches it to that task's action, and + * wires its memory, continuation, and Python-awaitable contexts. + * + *

Steps: + * + *

    + *
  1. Selects a Java or Python runner context based on the action's {@code Exec} type. + *
  2. Reuses any existing {@link RunnerContextImpl.MemoryContext} for this task; otherwise + * builds a fresh one backed by the supplied sensory/short-term memory states. + *
  3. Calls {@link RunnerContextImpl#switchActionContext} so the shared context now points at + * this action's name, memory, and key namespace. + *
  4. For Java contexts, attaches a continuation context (re-used if the task is resuming + * from an async suspend, fresh otherwise). + *
  5. For Python contexts, attaches the per-task awaitable reference (or {@code null} if the + * awaitable was lost across a checkpoint restore — the action will then re-execute). + *
+ * + * @param actionTask the task to be set up before execution. + * @param key the current Flink key. + * @param agentPlan the agent plan. + * @param resourceCache the resource cache. + * @param metricGroup the agent metric group. + * @param jobIdentifier the job identifier. + * @param mailboxThreadChecker hook used to assert mailbox-thread access from runner contexts. + * @param sensoryMemState keyed map state backing sensory memory. + * @param shortTermMemState keyed map state backing short-term memory. + * @param pythonRunnerContext the Python runner context, or {@code null} when no Python runtime + * is initialized. + */ void createAndSetRunnerContext( ActionTask actionTask, Object key, @@ -171,8 +247,17 @@ RunnerContextImpl.MemoryContext removeMemoryContext(ActionTask actionTask) { } /** - * Transfers memory, durable execution, continuation, and Python awaitable contexts from the - * completed action task to the generated (next) action task. + * Transfers per-task contexts from a finishing action task to the action task it generated. + * + *

Always transfers the memory context. For Java tasks, transfers the continuation context. + * For Python tasks, transfers the awaitable reference when present. The durable-execution + * context map lives on {@link DurableExecutionManager}, so that manager is passed in as a + * parameter rather than held as a field — this keeps the no-manager-to-manager-references + * design constraint intact. + * + * @param fromTask the finishing task whose contexts should be transferred. + * @param toTask the newly generated task that will inherit the contexts. + * @param durableExecManager used to copy the durable-execution context entry, if any. */ void transferContexts( ActionTask fromTask, ActionTask toTask, DurableExecutionManager durableExecManager) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java index 26564b2a2..91b8a3351 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -44,6 +44,31 @@ import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; +/** + * Owns the durable-execution side of {@link ActionExecutionOperator}: the optional {@link + * ActionStateStore}, the recovery-marker operator state, the per-checkpoint sequence-number map, + * and the per-{@link ActionTask} {@link RunnerContextImpl.DurableExecutionContext} map. + * + *

Durable mode is optional. If no {@link ActionStateStore} is configured (and none is + * pre-injected via the constructor), all {@code maybe*} methods are no-ops and {@link + * #hasDurableStore()} returns {@code false}. This lets the operator stay agnostic of whether + * durable execution is enabled. + * + *

Lifecycle: instantiated in the operator constructor. {@link + * #maybeInitActionStateStore(AgentConfiguration)} runs from BOTH the operator's {@code + * initializeState()} and {@code open()} — recovery requires the store to be configured before + * {@link #handleRecovery(OperatorStateBackend)} reads from it, and the {@code open()} call ensures + * the store is also available on the normal (non-recovery) path. The method creates a default + * Kafka-backed store when one was not pre-injected, and is idempotent on the second call. {@link + * #handleRecovery(OperatorStateBackend)} runs from the operator's {@code initializeState()} during + * recovery. {@link #initRecoveryMarkerState(OperatorStateBackend)} runs from the operator's {@code + * open()}. {@link #close()} closes the underlying store. + * + *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data + * flows via method parameters. In particular, {@link + * ActionTaskContextManager#transferContexts(ActionTask, ActionTask, DurableExecutionManager)} + * accepts this manager as a parameter so that the durable-context map can stay encapsulated here. + */ class DurableExecutionManager implements ActionStatePersister, AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(DurableExecutionManager.class); @@ -57,12 +82,27 @@ class DurableExecutionManager implements ActionStatePersister, AutoCloseable { private final Map actionTaskDurableContexts; + /** + * @param actionStateStore an optional pre-injected store, primarily for tests. When {@code + * null}, {@link #maybeInitActionStateStore(AgentConfiguration)} may create a default store + * based on configuration; otherwise durable execution is disabled. + */ DurableExecutionManager(@Nullable ActionStateStore actionStateStore) { this.actionStateStore = actionStateStore; this.checkpointIdToSeqNums = new HashMap<>(); this.actionTaskDurableContexts = new HashMap<>(); } + /** + * Lazily creates a default {@link ActionStateStore} from configuration if none was + * pre-injected. + * + *

Only creates a store when this manager was constructed without one and the configuration + * selects a recognized backend (currently Kafka). Otherwise this is a no-op, which leaves + * durable execution disabled. + * + * @param config the agent configuration carrying the backend selection. + */ void maybeInitActionStateStore(AgentConfiguration config) { if (actionStateStore == null && KAFKA.getType().equalsIgnoreCase(config.get(ACTION_STATE_STORE_BACKEND))) { @@ -84,9 +124,18 @@ void initRecoveryMarkerState(OperatorStateBackend operatorStateBackend) throws E } } - // Note: Re-creates the union list state descriptor here because handleRecovery() is called - // from initializeState() which runs BEFORE open(), so recoveryMarkerOpState is not yet - // initialized. The descriptor name matches exactly, so Flink returns the same state. + /** + * Replays recovery markers from the operator's union-list state to rebuild durable action + * state. + * + *

Called from the operator's {@code initializeState()}, which runs before {@code open()}. + * This means {@link #recoveryMarkerOpState} is not yet initialized, so the union-list state + * descriptor is re-created here using the same descriptor name — Flink returns the same + * underlying state. No-op when durable execution is disabled. + * + * @param operatorStateBackend the operator state backend used to obtain the recovery-marker + * union-list state. + */ void handleRecovery(OperatorStateBackend operatorStateBackend) throws Exception { if (actionStateStore != null) { List markers = new ArrayList<>(); @@ -120,6 +169,21 @@ void maybeInitActionState(Object key, long sequenceNum, Action action, Event eve } } + /** + * Persists the result of a finished {@link ActionTask} to the durable store. + * + *

No-op when no store is configured or when the task did not finish (e.g. it suspended on a + * continuation). On finish, accumulates the task's memory updates and emitted output events + * into the {@link ActionState}, marks it completed, persists it, and clears the in-context + * durable bookkeeping. + * + * @param key the key under which the action ran. + * @param sequenceNum the per-key message sequence number. + * @param action the action being persisted. + * @param event the input event that triggered this action. + * @param context the runner context whose memory updates will be folded into the action state. + * @param actionTaskResult the result of running the action task. + */ void maybePersistTaskResult( Object key, long sequenceNum, @@ -157,6 +221,19 @@ void maybePersistTaskResult( context.clearDurableExecutionContext(); } + /** + * Wires a {@link RunnerContextImpl.DurableExecutionContext} onto the given action task's runner + * context. + * + *

Returns immediately when no durable store is configured. Otherwise reuses an existing + * {@link RunnerContextImpl.DurableExecutionContext} held in the per-task map (i.e. when + * resuming a continuation), or creates a fresh one bound to this manager so that nested + * persists route back through {@link #persist}. + * + * @param actionTask the action task to attach the context to. + * @param actionState the action state for this (key, sequenceNum, action, event). + * @param seqNum the per-key sequence number. + */ void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState, long seqNum) { if (actionStateStore == null) { return; @@ -196,6 +273,16 @@ void maybePruneState(Object key, long sequenceNum) throws Exception { } } + /** + * Prunes durable state for all per-key sequence numbers that were captured at the time of the + * given checkpoint. + * + *

The mapping from checkpoint id to per-key sequence numbers must have been recorded earlier + * via {@link #recordCheckpointSequenceNumbers}. After pruning, the entry for that checkpoint is + * removed. No-op when durable execution is disabled. + * + * @param checkpointId the id of the completed checkpoint. + */ void notifyCheckpointComplete(long checkpointId) { if (actionStateStore != null) { Map keyToSeqNum = @@ -216,6 +303,15 @@ void snapshotRecoveryMarker() throws Exception { } } + /** + * Records the per-key sequence numbers observed when snapshotting the given checkpoint. + * + *

The recorded mapping is consulted later by {@link #notifyCheckpointComplete(long)} to + * prune durable state up to the sequence number that was committed by that checkpoint. + * + * @param checkpointId the checkpoint being snapshotted. + * @param seqNums the per-key sequence numbers captured during the snapshot. + */ void recordCheckpointSequenceNumbers(long checkpointId, Map seqNums) { checkpointIdToSeqNums.put(checkpointId, seqNums); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java index c8eb9501f..3c374a591 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -49,6 +49,33 @@ import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; import static org.apache.flink.util.Preconditions.checkState; +/** + * Handles event-side concerns for {@link ActionExecutionOperator}: input/output transformation + * between Java/Python representations, action lookup against the {@link AgentPlan}, event-logger + * and event-listener notification, and watermark draining via the per-key segment queue. + * + *

Owned state: + * + *

    + *
  • The {@link EventLogger} created from the agent plan's logging configuration (may be {@code + * null} when logging is disabled). + *
  • The list of registered {@link EventListener}s. + *
  • A reused {@link StreamRecord} used to emit outputs without per-record allocation. + *
  • The {@link SegmentedQueue} that orders watermarks behind in-flight keys so a watermark is + * only emitted once all keys ahead of it have finished. + *
  • The late-bound {@link BuiltInMetrics} provided in {@link #open(BuiltInMetrics)}. + *
+ * + *

Lifecycle: instantiated in the operator constructor (which decides {@link #inputIsJava}). + * {@link #open(BuiltInMetrics)} runs from the operator's {@code open()} once metrics are available. + * {@link #initEventLogger} also runs from the operator's {@code open()} once the runtime context is + * available (after metrics have been built). {@link #close()} closes the event logger. + * + *

Design constraint: package-private; no manager-to-manager held references. + * + * @param input record type + * @param output record type + */ class EventRouter implements AutoCloseable { private final boolean inputIsJava; @@ -64,6 +91,15 @@ class EventRouter implements AutoCloseable { this.eventListeners = new ArrayList<>(); } + /** + * Initializes mutable runtime state that depends on metrics being available. + * + *

Allocates the reused stream record and the segmented watermark queue, and stores the + * supplied {@link BuiltInMetrics} for use in {@link #notifyEventProcessed(Event)}. Called from + * the operator's {@code open()} once metric groups are constructed. + * + * @param builtInMetrics the operator's built-in metrics handle. + */ void open(BuiltInMetrics builtInMetrics) { this.reusedStreamRecord = new StreamRecord<>(null); this.keySegmentQueue = new SegmentedQueue(); @@ -77,6 +113,19 @@ void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { eventLogger.open(new EventLoggerOpenParams(runtimeContext)); } + /** + * Wraps an incoming record into an {@link Event} suitable for action dispatch. + * + *

Java pipelines wrap the raw input directly into a Java {@link InputEvent}. Python + * pipelines expect a two-field {@link Row} where the first field is the key and the second is + * the actual payload; the payload is converted to a Python event via the supplied {@link + * PythonActionExecutor}. + * + * @param input the raw input record. + * @param pythonActionExecutor the Python action executor (used only when input originates from + * Python). + * @return the wrapped input event. + */ @SuppressWarnings("unchecked") Event wrapToInputEvent(IN input, PythonActionExecutor pythonActionExecutor) { if (inputIsJava) { @@ -89,6 +138,18 @@ Event wrapToInputEvent(IN input, PythonActionExecutor pythonActionExecutor) { } } + /** + * Extracts the downstream output payload from an {@link OutputEvent}. + * + *

For a Java {@link OutputEvent}, returns the payload directly. For a Python {@link + * PythonEvent}, delegates to the supplied {@link PythonActionExecutor} to convert the Python + * output back into the Java output type. + * + * @param event the output event (must satisfy {@link EventUtil#isOutputEvent(Event)}). + * @param pythonActionExecutor the Python action executor (used only for Python events). + * @return the typed output payload. + * @throws IllegalStateException if the event is not a recognized output-event type. + */ @SuppressWarnings("unchecked") OUT getOutputFromOutputEvent(Event event, PythonActionExecutor pythonActionExecutor) { checkState(EventUtil.isOutputEvent(event)); @@ -112,6 +173,16 @@ List getActionsTriggeredBy(Event event, AgentPlan agentPlan) { } } + /** + * Notifies the configured event sinks (logger, listeners, metrics) that an event was processed. + * + *

If event logging is enabled, appends and immediately flushes the event. Then notifies + * every registered {@link EventListener}. Finally increments the {@code eventProcessed} + * built-in metric. The event logger is flushed per call as a temporary measure pending a + * batched flush mechanism. + * + * @param event the event that was just processed. + */ void notifyEventProcessed(Event event) throws Exception { EventContext eventContext = new EventContext(event); if (eventLogger != null) { @@ -131,6 +202,15 @@ void notifyEventProcessed(Event event) throws Exception { builtInMetrics.markEventProcessed(); } + /** + * Drains all watermarks from the segmented queue that are now eligible to be emitted. + * + *

A watermark becomes eligible once every key in the segment ahead of it has finished + * processing. This method pops watermarks in order and forwards each to the supplied {@link + * WatermarkEmitter}. + * + * @param watermarkEmitter callback that emits a watermark downstream. + */ void processEligibleWatermarks(WatermarkEmitter watermarkEmitter) throws Exception { Watermark mark = keySegmentQueue.popOldestWatermark(); while (mark != null) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index ffee9f241..8eae9c835 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -42,6 +42,31 @@ import static org.apache.flink.agents.runtime.utils.StateUtil.*; +/** + * Owns all Flink state used by {@link ActionExecutionOperator} and exposes a narrow API for + * accessing it. + * + *

Owned state: + * + *

    + *
  • Keyed list state of pending {@link ActionTask}s for the current key. + *
  • Keyed list state of pending {@link Event}s buffered while another input is processing. + *
  • Keyed value state holding the per-key message sequence number. + *
  • Keyed map states for sensory and short-term memory. + *
  • Operator union-list state of keys currently being processed (used after rescale to recover + * in-flight work). + *
+ * + *

Lifecycle: instantiated by the operator's {@code initializeState()} (the Flink lifecycle runs + * {@code initializeState} before {@code open}). Both {@link + * #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext)} and {@link + * #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's {@code + * open()}. There is no explicit close — the underlying state handles are owned by Flink. + * + *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data + * flows via method parameters (see for example {@link ActionTaskContextManager#transferContexts} + * which takes a {@link DurableExecutionManager} as an argument rather than holding one). + */ class OperatorStateManager { static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; @@ -56,6 +81,15 @@ class OperatorStateManager { OperatorStateManager() {} + /** + * Registers all keyed-state descriptors against the operator's runtime context. + * + *

Registers: sensory memory map state, short-term memory map state, the per-key message + * sequence-number value state, the per-key list of pending {@link ActionTask}s, and the per-key + * list of buffered {@link Event}s. Called from the operator's {@code open()} method. + * + * @param runtimeContext the operator's runtime context, used to obtain keyed state handles. + */ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext runtimeContext) throws Exception { // init sensoryMemState @@ -90,6 +124,16 @@ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class))); } + /** + * Registers operator-level (non-keyed) state. + * + *

Registers the {@code currentProcessingKeys} union-list state. A union-list lets every + * subtask see all keys after a rescale; the operator filters out keys that do not belong to the + * current subtask's key-group range during recovery. Called from the operator's {@code open()} + * method (after {@code super.open()} and after the keyed-state setup). + * + * @param operatorStateBackend the operator state backend used to obtain operator state. + */ void initializeOperatorStates(OperatorStateBackend operatorStateBackend) throws Exception { // We use UnionList here to ensure that the task can access all keys after parallelism // modifications. @@ -101,6 +145,12 @@ void initializeOperatorStates(OperatorStateBackend operatorStateBackend) throws "currentProcessingKeys", TypeInformation.of(Object.class))); } + /** + * Advances the per-key message sequence number. + * + *

If the state has no value for the current key, sets it to {@code 0L}. Otherwise increments + * the existing value by one. Must be called under a keyed context. + */ void initOrIncSequenceNumber() throws Exception { // Initialize the sequence number state if it does not exist. Long sequenceNumber = sequenceNumberKState.value(); @@ -119,6 +169,12 @@ boolean hasMoreActionTasks() throws Exception { return listStateNotEmpty(actionTasksKState); } + /** + * Removes and returns the next pending {@link ActionTask} for the current key. + * + * @return the next {@link ActionTask}, or {@code null} if the queue for the current key is + * empty. + */ @Nullable ActionTask pollNextActionTask() throws Exception { return pollFromListState(actionTasksKState); @@ -132,6 +188,12 @@ void addPendingInputEvent(Event event) throws Exception { pendingInputEventsKState.add(event); } + /** + * Removes and returns the next pending input {@link Event} buffered for the current key. + * + * @return the next buffered input {@link Event}, or {@code null} if the buffer for the current + * key is empty. + */ @Nullable Event pollNextPendingInputEvent() throws Exception { return pollFromListState(pendingInputEventsKState); @@ -178,6 +240,16 @@ boolean isKeyOwnedByCurrentSubtask( return currentSubtaskKeyGroupRange.contains(keyGroup); } + /** + * Captures the current per-key sequence numbers across all keys held by the given backend. + * + *

Invoked during checkpoint snapshotting so the caller can later associate the snapshot's + * per-key sequence numbers with a checkpoint id (see {@link + * DurableExecutionManager#recordCheckpointSequenceNumbers}). + * + * @param keyedStateBackend the keyed state backend to scan. + * @return an immutable map snapshot from key to its current sequence number. + */ @SuppressWarnings("unchecked") Map snapshotSequenceNumbers(KeyedStateBackend keyedStateBackend) throws Exception { @@ -191,6 +263,15 @@ Map snapshotSequenceNumbers(KeyedStateBackend keyedStateBackend return keyToSeqNum; } + /** + * Applies a function to the pending-input-event list state for every key in the backend. + * + *

Used during recovery to scan all keys that hold buffered input events so the operator can + * resume processing them after a restore. + * + * @param keyedStateBackend the keyed state backend to scan. + * @param function the function to apply per (key, list-state) pair. + */ @SuppressWarnings("unchecked") void forEachPendingInputEventKey( KeyedStateBackend keyedStateBackend, diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java index 46578fe02..ba765924f 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -41,6 +41,31 @@ import java.util.HashMap; +/** + * Owns the embedded Python runtime used by {@link ActionExecutionOperator} when an agent plan + * contains Python actions or Python-defined resources. + * + *

Owned state: + * + *

    + *
  • The {@link PythonEnvironmentManager} that prepares dependencies and the Pemja runtime. + *
  • The {@link PythonInterpreter} obtained from that environment. + *
  • The {@link PythonActionExecutor} (when the plan contains Python actions). + *
  • The {@link PythonRunnerContextImpl} consumed by Python actions. + *
  • The Java/Python resource adapters that bridge resource lookups across languages. + *
+ * + *

Lifecycle: instantiated by the operator's {@code open()} (lazy — not in the operator + * constructor), then immediately initialized via {@link #open} in the same call. {@link #open} is a + * no-op when the agent plan contains no Python actions and no Python resources — in that case all + * accessors return {@code null} and {@link #isInitialized()} returns {@code false}. {@link + * #close()} closes the owned resources in the reverse order of creation: {@code + * pythonActionExecutor} → {@code pythonInterpreter} → {@code pythonEnvironmentManager}. + * + *

Design constraint: package-private; no manager-to-manager held references. Other managers + * receive what they need (e.g. the Python runner context, the action executor) via method + * parameters. + */ class PythonBridgeManager implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(PythonBridgeManager.class); @@ -57,6 +82,27 @@ class PythonBridgeManager implements AutoCloseable { this.initialized = false; } + /** + * Initializes the Python runtime if the agent plan needs it. + * + *

Scans the agent plan for any {@link PythonFunction} action or {@link + * PythonResourceProvider}. If neither is present, this method is a no-op and {@link + * #isInitialized()} stays {@code false}. Otherwise it builds the {@link + * PythonEnvironmentManager}, opens an embedded {@link PythonInterpreter}, constructs the shared + * {@link PythonRunnerContextImpl}, wires the Java/Python resource adapters, and conditionally + * initializes the Python action executor and the Python resource adapter (each only when the + * corresponding component is present in the plan). + * + * @param agentPlan the agent plan describing actions and resources. + * @param resourceCache the resource cache visible to both languages. + * @param executionConfig used to derive Python dependency information. + * @param distributedCache used to resolve distributed Python files. + * @param tmpDirs Flink-managed temp directories made available to Python. + * @param jobId the Flink job id. + * @param metricGroup the agent metric group, exposed to Python via the runner context. + * @param mailboxThreadChecker hook used by the runner context to assert mailbox-thread access. + * @param jobIdentifier the job identifier used to scope Python state. + */ void open( AgentPlan agentPlan, ResourceCache resourceCache, @@ -152,11 +198,19 @@ private void initPythonResourceAdapter(AgentPlan agentPlan, ResourceCache resour agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); } + /** + * @return the Python action executor, or {@code null} if the agent plan contains no Python + * actions (or {@link #open} has not yet been called). + */ @Nullable PythonActionExecutor getPythonActionExecutor() { return pythonActionExecutor; } + /** + * @return the Python runner context, or {@code null} if no Python runtime was initialized + * because the agent plan has neither Python actions nor Python resources. 
+ */ @Nullable PythonRunnerContextImpl getPythonRunnerContext() { return pythonRunnerContext; From e2e4e03fb45c3223fe0bad17f91b0708eb923a3b Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Tue, 28 Apr 2026 21:59:57 -0700 Subject: [PATCH 11/14] [runtime] Add minimal contract tests for manager classes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address xintongsong's review feedback (#546 top-level): "unit tests to verify them. Currently they are a bit vague relying on the names and implementations to understand." Add focused contract tests covering the core promise of each manager: - DurableExecutionManagerTest (2): no-store mode is a silent no-op for every maybe* method; an injected store correctly receives a finished task's persisted result with isCompleted() set. - ActionTaskContextManagerTest (2): per-task maps remain isolated under put/get/remove cycles; createOrGetRunnerContext throws IllegalStateException with a helpful message when asked for a Python context but pythonRunnerContext is null. - EventRouterTest (2): wrapToInputEvent passes Java input through unchanged; getActionsTriggeredBy dispatches by event-class name. - PythonBridgeManagerTest (1): open() is a no-op when the agent plan contains neither Python actions nor Python resources. OperatorStateManager is left to integration coverage from ActionExecutionOperatorTest (21 cases) — a dedicated harness-based test is tracked as follow-up. Comprehensive coverage (snapshot/recovery, listener notification, watermark draining, full memory-context lifecycle) is also follow-up. 
--- .../ActionTaskContextManagerTest.java | 75 ++++++++++++++ .../operator/DurableExecutionManagerTest.java | 99 +++++++++++++++++++ .../runtime/operator/EventRouterTest.java | 59 +++++++++++ .../operator/PythonBridgeManagerTest.java | 61 ++++++++++++ .../agents/runtime/operator/TestActions.java | 55 +++++++++++ 5 files changed, 349 insertions(+) create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java create mode 100644 runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java new file mode 100644 index 000000000..69b7ecb56 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Contract tests for {@link ActionTaskContextManager}. */ +class ActionTaskContextManagerTest { + + @Test + void perTaskMapsAreIsolatedAcrossPutGetRemove() throws Exception { + try (ActionTaskContextManager mgr = new ActionTaskContextManager(1)) { + Action action = TestActions.noopAction(); + ActionTask t1 = new JavaActionTask("k", new InputEvent(1L), action); + ActionTask t2 = new JavaActionTask("k", new InputEvent(2L), action); + + ContinuationContext c1 = new ContinuationContext(); + mgr.putContinuationContext(t1, c1); + mgr.putPythonAwaitableRef(t2, "ref-2"); + + // Cross-task isolation: each map only carries the entry it was given. 
+ assertThat(mgr.getContinuationContext(t1)).isSameAs(c1); + assertThat(mgr.getContinuationContext(t2)).isNull(); + assertThat(mgr.getPythonAwaitableRef(t1)).isNull(); + assertThat(mgr.getPythonAwaitableRef(t2)).isEqualTo("ref-2"); + assertThat(mgr.hasContinuationContext(t1)).isTrue(); + assertThat(mgr.hasContinuationContext(t2)).isFalse(); + + // Remove and re-check + mgr.removeContinuationContext(t1); + mgr.removePythonAwaitableRef(t2); + assertThat(mgr.hasContinuationContext(t1)).isFalse(); + assertThat(mgr.getPythonAwaitableRef(t2)).isNull(); + } + } + + @Test + void createOrGetRunnerContextThrowsWhenPythonContextRequestedButNull() throws Exception { + try (ActionTaskContextManager mgr = new ActionTaskContextManager(1)) { + assertThatThrownBy( + () -> + mgr.createOrGetRunnerContext( + /* isJava */ false, + /* agentPlan */ null, + /* resourceCache */ null, + /* metricGroup */ null, + /* jobIdentifier */ "job", + /* mailboxThreadChecker */ () -> {}, + /* pythonRunnerContext */ null)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("PythonRunnerContextImpl has not been initialized"); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java new file mode 100644 index 000000000..98a8b4091 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.InMemoryActionStateStore; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** Contract tests for {@link DurableExecutionManager}. */ +class DurableExecutionManagerTest { + + @Test + void noStoreModeMakesAllMaybeOperationsNoOp() throws Exception { + DurableExecutionManager dem = new DurableExecutionManager(null); + // No ACTION_STATE_STORE_BACKEND set → no default store should be created. + dem.maybeInitActionStateStore(new AgentConfiguration()); + + assertThat(dem.hasDurableStore()).isFalse(); + assertThat(dem.getActionStateStore()).isNull(); + + Action action = TestActions.noopAction(); + Event event = new InputEvent(0L); + + // Every maybe* method must be a silent no-op. 
+ assertThat(dem.maybeGetActionState("k", 0L, action, event)).isNull(); + dem.maybeInitActionState("k", 0L, action, event); + dem.maybePruneState("k", 0L); + dem.notifyCheckpointComplete(1L); + dem.snapshotRecoveryMarker(); + dem.close(); + } + + @Test + void withInjectedStorePersistsTaskResult() throws Exception { + InMemoryActionStateStore store = new InMemoryActionStateStore(false); + DurableExecutionManager dem = new DurableExecutionManager(store); + + assertThat(dem.hasDurableStore()).isTrue(); + assertThat(dem.getActionStateStore()).isSameAs(store); + + Action action = TestActions.noopAction(); + Event event = new InputEvent(42L); + String key = "key-1"; + long seq = 0L; + + // First call seeds an initial ActionState in the store. + dem.maybeInitActionState(key, seq, action, event); + assertThat(store.getKeyedActionStates()).containsKey(key); + assertThat(dem.maybeGetActionState(key, seq, action, event)).isNotNull(); + + // Build a finished task result with one output event; verify persist folds it into state. 
+ Event outEvent = new OutputEvent(99L); + RunnerContextImpl context = mock(RunnerContextImpl.class); + when(context.getSensoryMemoryUpdates()).thenReturn(List.of()); + when(context.getShortTermMemoryUpdates()).thenReturn(List.of()); + + ActionTask.ActionTaskResult finishedResult = mock(ActionTask.ActionTaskResult.class); + when(finishedResult.isFinished()).thenReturn(true); + when(finishedResult.getOutputEvents()).thenReturn(List.of(outEvent)); + + dem.maybePersistTaskResult(key, seq, action, event, context, finishedResult); + + ActionState persisted = dem.maybeGetActionState(key, seq, action, event); + assertThat(persisted).isNotNull(); + assertThat(persisted.getOutputEvents()).contains(outEvent); + assertThat(persisted.isCompleted()).isTrue(); + verify(context).clearDurableExecutionContext(); + + dem.close(); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java new file mode 100644 index 000000000..a93ab70f1 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Contract tests for {@link EventRouter}. */ +class EventRouterTest { + + @Test + void wrapToInputEventReturnsJavaInputEventForJavaInput() { + AgentPlan plan = new AgentPlan(new HashMap<>(), new HashMap<>()); + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + + Event event = router.wrapToInputEvent(42L, /* pythonActionExecutor */ null); + + assertThat(event).isInstanceOf(InputEvent.class); + assertThat(((InputEvent) event).getInput()).isEqualTo(42L); + } + + @Test + void getActionsTriggeredByReturnsActionsForJavaEventClass() throws Exception { + Action action = TestActions.noopAction(); + Map actions = Map.of(action.getName(), action); + Map> byEvent = Map.of(InputEvent.class.getName(), List.of(action)); + AgentPlan plan = new AgentPlan(actions, byEvent); + + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + + List triggered = router.getActionsTriggeredBy(new InputEvent(0L), plan); + + assertThat(triggered).containsExactly(action); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java new file mode 100644 index 000000000..f7226826e --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Contract tests for {@link PythonBridgeManager}. */ +class PythonBridgeManagerTest { + + @Test + void openIsNoOpWhenPlanHasNeitherPythonActionsNorResources() throws Exception { + // Java-only plan: one Java action, no resources. 
+ Action javaAction = TestActions.noopAction(); + Map actions = Map.of(javaAction.getName(), javaAction); + Map> byEvent = Map.of(InputEvent.class.getName(), List.of(javaAction)); + AgentPlan plan = new AgentPlan(actions, byEvent); + + try (PythonBridgeManager bridge = new PythonBridgeManager()) { + bridge.open( + plan, + /* resourceCache */ null, + new ExecutionConfig(), + /* distributedCache */ null, + /* tmpDirs */ new String[] {System.getProperty("java.io.tmpdir")}, + /* jobId */ new JobID(), + /* metricGroup */ null, + /* mailboxThreadChecker */ () -> {}, + /* jobIdentifier */ "job-1"); + + // No-op contract: nothing initialized, no Pemja interpreter created. + assertThat(bridge.isInitialized()).isFalse(); + assertThat(bridge.getPythonActionExecutor()).isNull(); + assertThat(bridge.getPythonRunnerContext()).isNull(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java new file mode 100644 index 000000000..f427814a0 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.actions.Action; + +import java.util.List; + +/** + * Shared helpers for the manager contract tests in this package. + * + *

Provides a minimal {@link Action} backed by a no-op static Java function so individual tests + * do not need to redeclare the boilerplate around {@link JavaFunction#JavaFunction(Class, String, + * Class[])} signature checks. + */ +final class TestActions { + + private TestActions() {} + + /** Returns a minimal noop Java action backed by {@link #noop(InputEvent, RunnerContext)}. */ + static Action noopAction() { + try { + return new Action( + "noop", + new JavaFunction( + TestActions.class, + "noop", + new Class[] {InputEvent.class, RunnerContext.class}), + List.of(InputEvent.class.getName())); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** No-op static method referenced by {@link #noopAction()}. Must be public for reflection. */ + public static void noop(InputEvent event, RunnerContext context) {} +} From 3699f140414a354daf411a0768d78f086c2b61df Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 30 Apr 2026 23:18:49 -0700 Subject: [PATCH 12/14] [runtime] Remove @Nullable from always-initialized getters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit getKeySegmentQueue() and getReusedStreamRecord() in EventRouter are initialized unconditionally in open() before any data flows through. Similarly, getSensoryMemState() and getShortTermMemState() in OperatorStateManager are initialized before callers access them. @Nullable on these methods is misleading — retain it only where the return value genuinely can be null at call time (getEventLogger, pollNextActionTask, pollNextPendingInputEvent). Addresses wenjin272's review comments r3165417289 + r3165421647. 
--- .../org/apache/flink/agents/runtime/operator/EventRouter.java | 2 -- .../flink/agents/runtime/operator/OperatorStateManager.java | 2 -- 2 files changed, 4 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java index 3c374a591..be3ea76dd 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -219,12 +219,10 @@ void processEligibleWatermarks(WatermarkEmitter watermarkEmitter) throws Excepti } } - @Nullable SegmentedQueue getKeySegmentQueue() { return keySegmentQueue; } - @Nullable StreamRecord getReusedStreamRecord() { return reusedStreamRecord; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java index 8eae9c835..785d43beb 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -215,12 +215,10 @@ Iterable getProcessingKeys() throws Exception { return currentProcessingKeysOpState.get(); } - @Nullable MapState getSensoryMemState() { return sensoryMemState; } - @Nullable MapState getShortTermMemState() { return shortTermMemState; } From 07271af1037834e014c1730d26355160240fb660 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Fri, 1 May 2026 00:12:48 -0700 Subject: [PATCH 13/14] Retrigger CI From 3b9e9330cef5879620a1b0fde9a54c4936d6121a Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Fri, 1 May 2026 04:48:37 -0700 Subject: [PATCH 14/14] Retrigger CI