diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index e5015a3a5..2d8933a3c 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -18,96 +18,47 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.EventContext; -import org.apache.flink.agents.api.InputEvent; import org.apache.flink.agents.api.OutputEvent; import org.apache.flink.agents.api.agents.AgentExecutionOptions; import org.apache.flink.agents.api.context.MemoryUpdate; -import org.apache.flink.agents.api.listener.EventListener; -import org.apache.flink.agents.api.logger.EventLogger; -import org.apache.flink.agents.api.logger.EventLoggerConfig; -import org.apache.flink.agents.api.logger.EventLoggerFactory; -import org.apache.flink.agents.api.logger.EventLoggerOpenParams; -import org.apache.flink.agents.api.resource.Resource; -import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.PythonFunction; import org.apache.flink.agents.plan.actions.Action; -import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; -import org.apache.flink.agents.runtime.PythonMCPResourceDiscovery; import org.apache.flink.agents.runtime.ResourceCache; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; -import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; -import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; -import org.apache.flink.agents.runtime.async.ContinuationContext; -import 
org.apache.flink.agents.runtime.context.ActionStatePersister; -import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; -import org.apache.flink.agents.runtime.context.RunnerContextImpl; -import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; -import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; -import org.apache.flink.agents.runtime.eventlog.FileEventLogger; -import org.apache.flink.agents.runtime.memory.CachedMemoryStore; -import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; -import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; -import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; -import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.python.operator.PythonActionTask; -import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; -import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; -import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; import org.apache.flink.agents.runtime.utils.EventUtil; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.state.MapState; -import org.apache.flink.api.common.state.MapStateDescriptor; -import org.apache.flink.api.common.state.ValueState; -import org.apache.flink.api.common.state.ValueStateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; import 
org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; -import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.StreamTask; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor; -import org.apache.flink.types.Row; import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pemja.core.PythonInterpreter; import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; -import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; -import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; -import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static 
org.apache.flink.util.Preconditions.checkState; /** @@ -121,54 +72,25 @@ * and the resulting output event is collected for further processing. */ public class ActionExecutionOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput, ActionStatePersister { + implements OneInputStreamOperator, BoundedOneInput { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(ActionExecutionOperator.class); - private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; - private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; - private static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents"; - private final AgentPlan agentPlan; private transient ResourceCache resourceCache; - private final Boolean inputIsJava; - - private transient StreamRecord reusedStreamRecord; - - private transient MapState sensoryMemState; - - private transient MapState shortTermMemState; - - private transient PythonEnvironmentManager pythonEnvironmentManager; - - private transient PythonInterpreter pythonInterpreter; - - // PythonActionExecutor for Python actions - private transient PythonActionExecutor pythonActionExecutor; - - // RunnerContext for Python actions - private transient PythonRunnerContextImpl pythonRunnerContext; - - // PythonResourceAdapter for Python resources in Java actions - private transient PythonResourceAdapterImpl pythonResourceAdapter; - - // PythonResourceAdapter for Java resources in Python actions or Python resources - private transient JavaResourceAdapter javaResourceAdapter; + private transient PythonBridgeManager pythonBridge; private transient FlinkAgentsMetricGroupImpl metricGroup; private transient BuiltInMetrics builtInMetrics; - private transient SegmentedQueue keySegmentQueue; - private final transient MailboxExecutor mailboxExecutor; - // RunnerContext for Java Actions - private transient RunnerContextImpl 
runnerContext; + private transient ActionTaskContextManager contextManager; // We need to check whether the current thread is the mailbox thread using the mailbox // processor. @@ -177,50 +99,14 @@ public class ActionExecutionOperator extends AbstractStreamOperator actionTasksKState; - - // To avoid processing different InputEvents with the same key, we use a state to store pending - // InputEvents that are waiting to be processed. - private transient ListState pendingInputEventsKState; - - // An operator state is used to track the currently processing keys. This is useful when - // receiving an EndOfInput signal, as we need to wait until all related events are fully - // processed. - private transient ListState currentProcessingKeysOpState; - - private final transient EventLogger eventLogger; - private final transient List eventListeners; - - private transient ActionStateStore actionStateStore; - private transient ValueState sequenceNumberKState; - private transient ListState recoveryMarkerOpState; - private transient Map> checkpointIdToSeqNums; + private final transient EventRouter eventRouter; - // This in memory map keep track of the runner context for the async action task that having - // been finished - private final transient Map - actionTaskMemoryContexts; + private final transient DurableExecutionManager durableExecManager; - // This in memory map keeps track of the durable execution context for async action tasks - // that have not been finished, allowing recovery of currentCallIndex across invocations - private final transient Map - actionTaskDurableContexts; + private transient OperatorStateManager stateManager; - private final transient Map continuationContexts; - - private final transient Map pythonAwaitableRefs; - - // Each job can only have one identifier and this identifier must be consistent across restarts. 
- // We cannot use job id as the identifier here because user may change job id by - // creating a savepoint, stop the job and then resume from savepoint. - // We use this identifier to control the visibility for long-term memory. - // Inspired by Apache Paimon. private transient String jobIdentifier; - private transient ContinuationActionExecutor continuationActionExecutor; - public ActionExecutionOperator( AgentPlan agentPlan, Boolean inputIsJava, @@ -228,17 +114,10 @@ public ActionExecutionOperator( MailboxExecutor mailboxExecutor, ActionStateStore actionStateStore) { this.agentPlan = agentPlan; - this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; this.mailboxExecutor = mailboxExecutor; - this.eventLogger = createEventLogger(agentPlan); - this.eventListeners = new ArrayList<>(); - this.actionStateStore = actionStateStore; - this.checkpointIdToSeqNums = new HashMap<>(); - this.actionTaskMemoryContexts = new HashMap<>(); - this.actionTaskDurableContexts = new HashMap<>(); - this.continuationContexts = new HashMap<>(); - this.pythonAwaitableRefs = new HashMap<>(); + this.eventRouter = new EventRouter<>(agentPlan, inputIsJava); + this.durableExecManager = new DurableExecutionManager(actionStateStore); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -253,82 +132,42 @@ public void setup( @Override public void open() throws Exception { super.open(); - reusedStreamRecord = new StreamRecord<>(null); - // init sensoryMemState - MapStateDescriptor sensoryMemStateDescriptor = - new MapStateDescriptor<>( - "sensoryMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); - - // init shortTermMemState - MapStateDescriptor shortTermMemStateDescriptor = - new MapStateDescriptor<>( - "shortTermMemory", - TypeInformation.of(String.class), - TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); - 
shortTermMemState = getRuntimeContext().getMapState(shortTermMemStateDescriptor); + + stateManager.initializeKeyedStates(getRuntimeContext()); + stateManager.initializeOperatorStates(getOperatorStateBackend()); resourceCache = new ResourceCache(agentPlan.getResourceProviders()); metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); - keySegmentQueue = new SegmentedQueue(); + eventRouter.open(builtInMetrics); - maybeInitActionStateStore(); - - if (actionStateStore != null) { - // init recovery marker state for recovery marker persistence - recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - } - // init sequence number state for per key message ordering - sequenceNumberKState = - getRuntimeContext() - .getState( - new ValueStateDescriptor<>( - MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); - - // init agent processing related state - actionTasksKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - "actionTasks", TypeInformation.of(ActionTask.class))); - pendingInputEventsKState = - getRuntimeContext() - .getListState( - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, - TypeInformation.of(Event.class))); - // We use UnionList here to ensure that the task can access all keys after parallelism - // modifications. - // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do - // not belong to the key range of current task. 
- currentProcessingKeysOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - "currentProcessingKeys", TypeInformation.of(Object.class))); + durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); + durableExecManager.initRecoveryMarkerState(getOperatorStateBackend()); // init PythonActionExecutor and PythonResourceAdapter - initPythonEnvironment(); - - // init executor for Java async execution - continuationActionExecutor = - new ContinuationActionExecutor( + pythonBridge = new PythonBridgeManager(); + pythonBridge.open( + agentPlan, + resourceCache, + getExecutionConfig(), + getRuntimeContext().getDistributedCache(), + getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(), + getRuntimeContext().getJobInfo().getJobId(), + metricGroup, + this::checkMailboxThread, + jobIdentifier); + + // init context manager for runner context creation and memory contexts + contextManager = + new ActionTaskContextManager( agentPlan.getConfig().get(AgentExecutionOptions.NUM_ASYNC_THREADS)); mailboxProcessor = getMailboxProcessor(); // Initialize the event logger if it is set. 
- initEventLogger(getRuntimeContext()); + eventRouter.initEventLogger(getRuntimeContext()); // Since an operator restart may change the key range it manages due to changes in // parallelism, @@ -337,17 +176,10 @@ public void open() throws Exception { tryResumeProcessActionTasks(); } - private void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception { - if (eventLogger == null) { - return; - } - eventLogger.open(new EventLoggerOpenParams(runtimeContext)); - } - @Override public void processWatermark(Watermark mark) throws Exception { - keySegmentQueue.addWatermark(mark); - processEligibleWatermarks(); + eventRouter.getKeySegmentQueue().addWatermark(mark); + eventRouter.processEligibleWatermarks(super::processWatermark); } @Override @@ -356,18 +188,19 @@ public void processElement(StreamRecord record) throws Exception { LOG.debug("Receive an element {}", input); // wrap to InputEvent first - Event inputEvent = wrapToInputEvent(input); + Event inputEvent = + eventRouter.wrapToInputEvent(input, pythonBridge.getPythonActionExecutor()); if (record.hasTimestamp()) { inputEvent.setSourceTimestamp(record.getTimestamp()); } - keySegmentQueue.addKeyToLastSegment(getCurrentKey()); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(getCurrentKey()); - if (currentKeyHasMoreActionTask()) { + if (stateManager.hasMoreActionTasks()) { // If there are already actions being processed for the current key, the newly incoming // event should be queued and processed later. Therefore, we add it to // pendingInputEventsState. - pendingInputEventsKState.add(inputEvent); + stateManager.addPendingInputEvent(inputEvent); } else { // Otherwise, the new event is processed immediately. processEvent(getCurrentKey(), inputEvent); @@ -379,30 +212,35 @@ public void processElement(StreamRecord record) throws Exception { * `tryProcessActionTaskForKey` to continue processing. 
*/ private void processEvent(Object key, Event event) throws Exception { - notifyEventProcessed(event); + eventRouter.notifyEventProcessed(event); boolean isInputEvent = EventUtil.isInputEvent(event); if (EventUtil.isOutputEvent(event)) { // If the event is an OutputEvent, we send it downstream. - OUT outputData = getOutputFromOutputEvent(event); + OUT outputData = + eventRouter.getOutputFromOutputEvent( + event, pythonBridge.getPythonActionExecutor()); if (event.hasSourceTimestamp()) { - output.collect(reusedStreamRecord.replace(outputData, event.getSourceTimestamp())); + output.collect( + eventRouter + .getReusedStreamRecord() + .replace(outputData, event.getSourceTimestamp())); } else { - reusedStreamRecord.eraseTimestamp(); - output.collect(reusedStreamRecord.replace(outputData)); + eventRouter.getReusedStreamRecord().eraseTimestamp(); + output.collect(eventRouter.getReusedStreamRecord().replace(outputData)); } } else { if (isInputEvent) { // If the event is an InputEvent, we mark that the key is currently being processed. - currentProcessingKeysOpState.add(key); - initOrIncSequenceNumber(); + stateManager.addProcessingKey(key); + stateManager.initOrIncSequenceNumber(); } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. - List triggerActions = getActionsTriggeredBy(event); + List triggerActions = eventRouter.getActionsTriggeredBy(event, agentPlan); if (triggerActions != null && !triggerActions.isEmpty()) { for (Action triggerAction : triggerActions) { - actionTasksKState.add(createActionTask(key, triggerAction, event)); + stateManager.addActionTask(createActionTask(key, triggerAction, event)); } } } @@ -413,25 +251,6 @@ private void processEvent(Object key, Event event) throws Exception { } } - private void notifyEventProcessed(Event event) throws Exception { - EventContext eventContext = new EventContext(event); - if (eventLogger != null) { - // If event logging is enabled, we log the event along with its context. 
- eventLogger.append(eventContext, event); - // For now, we flush the event logger after each event to ensure immediate logging. - // This is a temporary solution to ensure that events are logged immediately. - // TODO: In the future, we may want to implement a more efficient batching mechanism. - eventLogger.flush(); - } - if (eventListeners != null) { - // Notify all registered event listeners about the event. - for (EventListener listener : eventListeners) { - listener.onEventProcessed(eventContext, event); - } - } - builtInMetrics.markEventProcessed(); - } - private void tryProcessActionTaskForKey(Object key) { try { processActionTaskForKey(key); @@ -449,9 +268,9 @@ private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. setCurrentKey(key); - ActionTask actionTask = pollFromListState(actionTasksKState); + ActionTask actionTask = stateManager.pollNextActionTask(); if (actionTask == null) { - int removedCount = removeFromListState(currentProcessingKeysOpState, key); + int removedCount = stateManager.removeProcessingKey(key); checkState( removedCount == 1, "Current processing key count for key " @@ -459,21 +278,32 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); + eventRouter.processEligibleWatermarks(super::processWatermark); return; } // 2. Invoke the action task. 
- createAndSetRunnerContext(actionTask, key); - - long sequenceNumber = sequenceNumberKState.value(); + contextManager.createAndSetRunnerContext( + actionTask, + key, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + this::checkMailboxThread, + stateManager.getSensoryMemState(), + stateManager.getShortTermMemState(), + pythonBridge.getPythonRunnerContext()); + + long sequenceNumber = stateManager.getSequenceNumber(); boolean isFinished; List outputEvents; Optional generatedActionTaskOpt = Optional.empty(); ActionState actionState = - maybeGetActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecManager.maybeGetActionState( + key, sequenceNumber, actionTask.action, actionTask.event); // Check if action is already completed if (actionState != null && actionState.isCompleted()) { @@ -500,28 +330,30 @@ private void processActionTaskForKey(Object key) throws Exception { } else { // Initialize ActionState if not exists, or use existing one for recovery if (actionState == null) { - maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + durableExecManager.maybeInitActionState( + key, sequenceNumber, actionTask.action, actionTask.event); actionState = - maybeGetActionState( + durableExecManager.maybeGetActionState( key, sequenceNumber, actionTask.action, actionTask.event); } // Set up durable execution context for fine-grained recovery - setupDurableExecutionContext(actionTask, actionState); + durableExecManager.setupDurableExecutionContext( + actionTask, actionState, sequenceNumber); ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke( getRuntimeContext().getUserCodeClassLoader(), - this.pythonActionExecutor); + this.pythonBridge.getPythonActionExecutor()); // We remove the contexts from the map after the task is processed. They will be added // back later if the action task has a generated action task, meaning it is not // finished. 
- actionTaskMemoryContexts.remove(actionTask); - actionTaskDurableContexts.remove(actionTask); - continuationContexts.remove(actionTask); - pythonAwaitableRefs.remove(actionTask); - maybePersistTaskResult( + contextManager.removeMemoryContext(actionTask); + durableExecManager.removeDurableContext(actionTask); + contextManager.removeContinuationContext(actionTask); + contextManager.removePythonAwaitableRef(actionTask); + durableExecManager.maybePersistTaskResult( key, sequenceNumber, actionTask.action, @@ -540,7 +372,7 @@ private void processActionTaskForKey(Object key) throws Exception { boolean currentInputEventFinished = false; if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); - currentInputEventFinished = !currentKeyHasMoreActionTask(); + currentInputEventFinished = !stateManager.hasMoreActionTasks(); // Persist memory to the Flink state when the action task is finished. actionTask.getRunnerContext().persistMemory(); @@ -555,29 +387,9 @@ private void processActionTaskForKey(Object key) throws Exception { // If the action task is not finished, we keep the contexts in memory for the // next generated ActionTask to be invoked. 
- actionTaskMemoryContexts.put( - generatedActionTask, actionTask.getRunnerContext().getMemoryContext()); - RunnerContextImpl.DurableExecutionContext durableContext = - actionTask.getRunnerContext().getDurableExecutionContext(); - if (durableContext != null) { - actionTaskDurableContexts.put(generatedActionTask, durableContext); - } - if (actionTask.getRunnerContext() instanceof JavaRunnerContextImpl) { - continuationContexts.put( - generatedActionTask, - ((JavaRunnerContextImpl) actionTask.getRunnerContext()) - .getContinuationContext()); - } - if (actionTask.getRunnerContext() instanceof PythonRunnerContextImpl) { - String awaitableRef = - ((PythonRunnerContextImpl) actionTask.getRunnerContext()) - .getPythonAwaitableRef(); - if (awaitableRef != null) { - pythonAwaitableRefs.put(generatedActionTask, awaitableRef); - } - } + contextManager.transferContexts(actionTask, generatedActionTask, durableExecManager); - actionTasksKState.add(generatedActionTask); + stateManager.addActionTask(generatedActionTask); } // 3. Process the next InputEvent or next action task @@ -587,8 +399,8 @@ private void processActionTaskForKey(Object key) throws Exception { // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. 
- int removedCount = removeFromListState(currentProcessingKeysOpState, key); - maybePruneState(key, sequenceNumber); + int removedCount = stateManager.removeProcessingKey(key); + durableExecManager.maybePruneState(key, sequenceNumber); checkState( removedCount == 1, "Current processing key count for key " @@ -596,108 +408,20 @@ private void processActionTaskForKey(Object key) throws Exception { + " should be 1, but got " + removedCount); checkState( - keySegmentQueue.removeKey(key), + eventRouter.getKeySegmentQueue().removeKey(key), "Current key" + key + " is missing from the segmentedQueue."); - processEligibleWatermarks(); - Event pendingInputEvent = pollFromListState(pendingInputEventsKState); + eventRouter.processEligibleWatermarks(super::processWatermark); + Event pendingInputEvent = stateManager.pollNextPendingInputEvent(); if (pendingInputEvent != null) { processEvent(key, pendingInputEvent); } - } else if (currentKeyHasMoreActionTask()) { + } else if (stateManager.hasMoreActionTasks()) { // If the current key has additional action tasks remaining, we should submit a new mail // to continue processing them. 
mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); } } - private Resource getResource(String name, ResourceType type) { - try { - return resourceCache.getResource(name, type); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private void initPythonEnvironment() throws Exception { - boolean containPythonAction = - agentPlan.getActions().values().stream() - .anyMatch(action -> action.getExec() instanceof PythonFunction); - - boolean containPythonResource = - agentPlan.getResourceProviders().values().stream() - .anyMatch( - resourceProviderMap -> - resourceProviderMap.values().stream() - .anyMatch( - resourceProvider -> - resourceProvider - instanceof - PythonResourceProvider)); - - if (containPythonAction || containPythonResource) { - LOG.debug("Begin initialize PythonEnvironmentManager."); - PythonDependencyInfo dependencyInfo = - PythonDependencyInfo.create( - getExecutionConfig().toConfiguration(), - getRuntimeContext().getDistributedCache()); - pythonEnvironmentManager = - new PythonEnvironmentManager( - dependencyInfo, - getContainingTask() - .getEnvironment() - .getTaskManagerInfo() - .getTmpDirectories(), - new HashMap<>(System.getenv()), - getRuntimeContext().getJobInfo().getJobId()); - pythonEnvironmentManager.open(); - EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); - pythonInterpreter = env.getInterpreter(); - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - this.jobIdentifier); - - javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); - if (containPythonResource) { - initPythonResourceAdapter(); - } - if (containPythonAction) { - initPythonActionExecutor(); - } - } - } - - private void initPythonActionExecutor() throws Exception { - pythonActionExecutor = - new PythonActionExecutor( - pythonInterpreter, - agentPlan, - javaResourceAdapter, - 
pythonRunnerContext, - jobIdentifier); - pythonActionExecutor.open(); - } - - private void initPythonResourceAdapter() throws Exception { - pythonResourceAdapter = - new PythonResourceAdapterImpl( - (String anotherName, ResourceType anotherType) -> { - try { - return resourceCache.getResource(anotherName, anotherType); - } catch (Exception e) { - throw new RuntimeException(e); - } - }, - pythonInterpreter, - javaResourceAdapter); - pythonResourceAdapter.open(); - PythonMCPResourceDiscovery.discoverPythonMCPResources( - agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); - } - @Override public void endInput() throws Exception { waitInFlightEventsFinished(); @@ -705,7 +429,7 @@ public void endInput() throws Exception { @VisibleForTesting public void waitInFlightEventsFinished() throws Exception { - while (listStateNotEmpty(currentProcessingKeysOpState)) { + while (stateManager.hasProcessingKeys()) { mailboxExecutor.yield(); } } @@ -716,30 +440,17 @@ public void close() throws Exception { if (resourceCache != null) { resourceCache.close(); } - if (runnerContext != null) { - try { - runnerContext.close(); - } finally { - runnerContext = null; - } - } - if (pythonActionExecutor != null) { - pythonActionExecutor.close(); - } - if (pythonInterpreter != null) { - pythonInterpreter.close(); - } - if (pythonEnvironmentManager != null) { - pythonEnvironmentManager.close(); + if (contextManager != null) { + contextManager.close(); } - if (eventLogger != null) { - eventLogger.close(); + if (pythonBridge != null) { + pythonBridge.close(); } - if (actionStateStore != null) { - actionStateStore.close(); + if (eventRouter != null) { + eventRouter.close(); } - if (continuationActionExecutor != null) { - continuationActionExecutor.close(); + if (durableExecManager != null) { + durableExecManager.close(); } super.close(); @@ -749,32 +460,16 @@ public void close() throws Exception { public void initializeState(StateInitializationContext context) throws Exception { 
super.initializeState(context); - maybeInitActionStateStore(); - - if (actionStateStore != null) { - List markers = new ArrayList<>(); - - // We use UnionList here to ensure that the task can access all the recovery marker - // after - // parallelism modifications. - // The ActionStateStore will decide how to use the recovery markers. - ListState recoveryMarkerOpState = - getOperatorStateBackend() - .getUnionListState( - new ListStateDescriptor<>( - RECOVERY_MARKER_STATE_NAME, - TypeInformation.of(Object.class))); - - Iterable recoveryMarkers = recoveryMarkerOpState.get(); - if (recoveryMarkers != null) { - recoveryMarkers.forEach(markers::add); - } - LOG.info("Rebuilding action state from {} recovery markers", markers.size()); - actionStateStore.rebuildState(markers); - } + durableExecManager.maybeInitActionStateStore(agentPlan.getConfig()); + durableExecManager.handleRecovery(getOperatorStateBackend()); + + stateManager = new OperatorStateManager(); - // Get job identifier from user configuration. - // If not configured, get from state. + // Resolve the agent's stable job identifier: + // - If the user set it via AgentConfigOptions.JOB_IDENTIFIER, use that. + // - Otherwise fall back to the current Flink JobID, cached in operator + // state so the value remains stable across job restarts (Flink + // generates a fresh JobID on each restart). 
jobIdentifier = agentPlan.getConfig().get(JOB_IDENTIFIER); if (jobIdentifier == null) { String initialJobIdentifier = getRuntimeContext().getJobInfo().getJobId().toString(); @@ -786,71 +481,21 @@ public void initializeState(StateInitializationContext context) throws Exception @Override public void snapshotState(StateSnapshotContext context) throws Exception { - if (actionStateStore != null) { - Object recoveryMarker = actionStateStore.getRecoveryMarker(); - if (recoveryMarker != null) { - recoveryMarkerOpState.update(List.of(recoveryMarker)); - } - } + durableExecManager.snapshotRecoveryMarker(); - HashMap keyToSeqNum = new HashMap<>(); - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), - (key, state) -> keyToSeqNum.put(key, state.value())); - checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum); + Map keyToSeqNum = + stateManager.snapshotSequenceNumbers(getKeyedStateBackend()); + durableExecManager.recordCheckpointSequenceNumbers(context.getCheckpointId(), keyToSeqNum); super.snapshotState(context); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { - if (actionStateStore != null) { - Map keyToSeqNum = - checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); - for (Map.Entry entry : keyToSeqNum.entrySet()) { - actionStateStore.pruneState(entry.getKey(), entry.getValue()); - } - checkpointIdToSeqNums.remove(checkpointId); - } + durableExecManager.notifyCheckpointComplete(checkpointId); super.notifyCheckpointComplete(checkpointId); } - private Event wrapToInputEvent(IN input) { - if (inputIsJava) { - return new InputEvent(input); - } else { - // the input data must originate from Python and be of type Row with two fields — the - // first representing the key, and the second representing the actual data payload. 
- checkState(input instanceof Row && ((Row) input).getArity() == 2); - return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); - } - } - - private OUT getOutputFromOutputEvent(Event event) { - checkState(EventUtil.isOutputEvent(event)); - if (event instanceof OutputEvent) { - return (OUT) ((OutputEvent) event).getOutput(); - } else if (event instanceof PythonEvent) { - Object outputFromOutputEvent = - pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); - return (OUT) outputFromOutputEvent; - } else { - throw new IllegalStateException( - "Unsupported event type: " + event.getClass().getName()); - } - } - - private List getActionsTriggeredBy(Event event) { - if (event instanceof PythonEvent) { - return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); - } else { - return agentPlan.getActionsTriggeredBy(event.getClass().getName()); - } - } - private MailboxProcessor getMailboxProcessor() throws Exception { Field field = MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor"); field.setAccessible(true); @@ -874,287 +519,48 @@ private ActionTask createActionTask(Object key, Action action, Event event) { } } - private void createAndSetRunnerContext(ActionTask actionTask, Object key) { - RunnerContextImpl runnerContext; - if (actionTask.action.getExec() instanceof JavaFunction) { - runnerContext = createOrGetRunnerContext(true); - } else if (actionTask.action.getExec() instanceof PythonFunction) { - runnerContext = createOrGetRunnerContext(false); - } else { - throw new IllegalStateException( - "Unsupported action type: " + actionTask.action.getExec().getClass()); - } - - RunnerContextImpl.MemoryContext memoryContext; - if (actionTaskMemoryContexts.containsKey(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. 
- memoryContext = actionTaskMemoryContexts.get(actionTask); - } else { - memoryContext = - new RunnerContextImpl.MemoryContext( - new CachedMemoryStore(sensoryMemState), - new CachedMemoryStore(shortTermMemState)); - } - - runnerContext.switchActionContext( - actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); - - if (runnerContext instanceof JavaRunnerContextImpl) { - ContinuationContext continuationContext; - if (continuationContexts.containsKey(actionTask)) { - // action task for async execution action, should retrieve intermediate results from - // map. - continuationContext = continuationContexts.get(actionTask); - } else { - continuationContext = new ContinuationContext(); - } - ((JavaRunnerContextImpl) runnerContext).setContinuationContext(continuationContext); - } - if (runnerContext instanceof PythonRunnerContextImpl) { - // Get the awaitable ref from the transient map. After checkpoint restore, this will be - // null, signaling that the awaitable was lost and needs re-execution. 
- String awaitableRef = pythonAwaitableRefs.get(actionTask); - ((PythonRunnerContextImpl) runnerContext).setPythonAwaitableRef(awaitableRef); - } - actionTask.setRunnerContext(runnerContext); - } - - private boolean currentKeyHasMoreActionTask() throws Exception { - return listStateNotEmpty(actionTasksKState); - } - private void tryResumeProcessActionTasks() throws Exception { - Iterable keys = currentProcessingKeysOpState.get(); + Iterable keys = stateManager.getProcessingKeys(); if (keys != null) { int maxParallelism = getRuntimeContext().getTaskInfo().getMaxNumberOfParallelSubtasks(); KeyGroupRange currentSubtaskKeyGroupRange = - getCurrentSubtaskKeyGroupRange(maxParallelism); + stateManager.getCurrentSubtaskKeyGroupRange( + maxParallelism, getRuntimeContext()); for (Object key : keys) { - if (!isKeyOwnedByCurrentSubtask(key, maxParallelism, currentSubtaskKeyGroupRange)) { + if (!stateManager.isKeyOwnedByCurrentSubtask( + key, maxParallelism, currentSubtaskKeyGroupRange)) { continue; } - keySegmentQueue.addKeyToLastSegment(key); + eventRouter.getKeySegmentQueue().addKeyToLastSegment(key); mailboxExecutor.submit( () -> tryProcessActionTaskForKey(key), "process action task"); } } - getKeyedStateBackend() - .applyToAllKeys( - VoidNamespace.INSTANCE, - VoidNamespaceSerializer.INSTANCE, - new ListStateDescriptor<>( - PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), - (key, state) -> - state.get() - .forEach( - event -> keySegmentQueue.addKeyToLastSegment(key))); - } - - private void initOrIncSequenceNumber() throws Exception { - // Initialize the sequence number state if it does not exist. - Long sequenceNumber = sequenceNumberKState.value(); - if (sequenceNumber == null) { - sequenceNumberKState.update(0L); - } else { - sequenceNumberKState.update(sequenceNumber + 1); - } - } - - private ActionState maybeGetActionState( - Object key, long sequenceNum, Action action, Event event) throws Exception { - return actionStateStore == null - ? 
null - : actionStateStore.get(key.toString(), sequenceNum, action, event); - } - - private void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) - throws Exception { - if (actionStateStore != null) { - // Initialize the action state if it does not exist. It will exist when the action is an - // async action and - // has been persisted before the action task is finished. - if (actionStateStore.get(key, sequenceNum, action, event) == null) { - actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); - } - } - } - - private void maybePersistTaskResult( - Object key, - long sequenceNum, - Action action, - Event event, - RunnerContextImpl context, - ActionTask.ActionTaskResult actionTaskResult) - throws Exception { - if (actionStateStore == null) { - return; - } - - // if the task is not finished, we skip the persistence for now and wait until it is - // finished. - if (!actionTaskResult.isFinished()) { - return; - } - - ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); - - for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { - actionState.addSensoryMemoryUpdate(memoryUpdate); - } - - for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { - actionState.addShortTermMemoryUpdate(memoryUpdate); - } - - for (Event outputEvent : actionTaskResult.getOutputEvents()) { - actionState.addEvent(outputEvent); - } - - // Mark the action as completed and clear call records - // This indicates that recovery should skip the entire action - actionState.markCompleted(); - - actionStateStore.put(key, sequenceNum, action, event, actionState); - - // Clear durable execution context - context.clearDurableExecutionContext(); + stateManager.forEachPendingInputEventKey( + getKeyedStateBackend(), + (key, state) -> + state.get() + .forEach( + event -> + eventRouter + .getKeySegmentQueue() + .addKeyToLastSegment(key))); } - /** - * Sets up the durable execution context for 
fine-grained recovery. - * - *

This method initializes the runner context with a {@link - * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to: - * - *

    - *
  • Skip re-execution for already completed calls during recovery - *
  • Persist CallRecords after each code block completion - *
- */ - private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) { - if (actionStateStore == null) { - return; - } - - RunnerContextImpl.DurableExecutionContext durableContext; - if (actionTaskDurableContexts.containsKey(actionTask)) { - // Reuse existing context for async action continuation - durableContext = actionTaskDurableContexts.get(actionTask); - } else { - // Create new context for first invocation - final long sequenceNumber; - try { - sequenceNumber = sequenceNumberKState.value(); - } catch (Exception e) { - throw new RuntimeException("Failed to get sequence number from state", e); - } - - durableContext = - new RunnerContextImpl.DurableExecutionContext( - actionTask.getKey(), - sequenceNumber, - actionTask.action, - actionTask.event, - actionState, - this); - } - - actionTask.getRunnerContext().setDurableExecutionContext(durableContext); - } - - @Override - public void persist( - Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { - try { - actionStateStore.put(key, sequenceNumber, action, event, actionState); - } catch (Exception e) { - LOG.error("Failed to persist ActionState", e); - throw new RuntimeException("Failed to persist ActionState", e); - } - } - - private void maybePruneState(Object key, long sequenceNum) throws Exception { - if (actionStateStore != null) { - actionStateStore.pruneState(key, sequenceNum); - } - } - - private void processEligibleWatermarks() throws Exception { - Watermark mark = keySegmentQueue.popOldestWatermark(); - while (mark != null) { - super.processWatermark(mark); - mark = keySegmentQueue.popOldestWatermark(); - } - } - - private RunnerContextImpl createOrGetRunnerContext(Boolean isJava) { - if (isJava) { - if (runnerContext == null) { - if (continuationActionExecutor == null) { - continuationActionExecutor = - new ContinuationActionExecutor( - agentPlan - .getConfig() - .get(AgentExecutionOptions.NUM_ASYNC_THREADS)); - } - runnerContext = - new 
JavaRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - this.jobIdentifier, - continuationActionExecutor); - } - return runnerContext; - } else { - if (pythonRunnerContext == null) { - pythonRunnerContext = - new PythonRunnerContextImpl( - this.metricGroup, - this::checkMailboxThread, - this.agentPlan, - this.resourceCache, - jobIdentifier); - } - return pythonRunnerContext; - } - } - - private EventLogger createEventLogger(AgentPlan agentPlan) { - EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); - String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); - if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { - loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); - } - loggerConfigBuilder.property( - FileEventLogger.PRETTY_PRINT_PROPERTY_KEY, agentPlan.getConfig().get(PRETTY_PRINT)); - return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); - } - - private void maybeInitActionStateStore() { - if (actionStateStore == null - && KAFKA.getType() - .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) { - LOG.info("Using Kafka as backend of action state store."); - actionStateStore = new KafkaActionStateStore(agentPlan.getConfig()); - } + @VisibleForTesting + DurableExecutionManager getDurableExecutionManager() { + return durableExecManager; } - private KeyGroupRange getCurrentSubtaskKeyGroupRange(int maxParallelism) { - int parallelism = getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks(); - int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(); - return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex( - maxParallelism, parallelism, subtaskIndex); + @VisibleForTesting + EventRouter getEventRouter() { + return eventRouter; } - private boolean isKeyOwnedByCurrentSubtask( - Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) { - int keyGroup = 
KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism); - return currentSubtaskKeyGroupRange.contains(keyGroup); + @VisibleForTesting + OperatorStateManager getOperatorStateManager() { + return stateManager; } /** Failed to execute Action task. */ diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java new file mode 100644 index 000000000..5ffc37d17 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManager.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.runtime.ResourceCache; +import org.apache.flink.agents.runtime.async.ContinuationActionExecutor; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.apache.flink.agents.runtime.context.JavaRunnerContextImpl; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.api.common.state.MapState; + +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; + +/** + * Owns the per-{@link ActionTask} runtime context bookkeeping for {@link ActionExecutionOperator}. + * + *

Owned state: + * + *

    + *
  • The shared (Java) {@link RunnerContextImpl} that is reused across action tasks via {@link + * RunnerContextImpl#switchActionContext}. + *
  • Three per-{@link ActionTask} maps that survive across the boundary between a finishing + * action and the action it generates: memory contexts, continuation contexts (for async Java + * actions), and Python awaitable references. + *
  • The {@link ContinuationActionExecutor} thread pool used to run async Java continuations. + *
+ * + *

Lifecycle: instantiated by the operator's {@code open()} with the configured async-thread + * count from the agent plan. Has no separate {@code open()} step — fully constructed in the + * operator's {@code open()}. {@link #close()} closes the shared runner context and the continuation + * executor. + * + *

Note: the Python {@link RunnerContextImpl} is not owned here — it is owned by {@link + * PythonBridgeManager} and passed in as a parameter to {@link #createOrGetRunnerContext} and {@link + * #createAndSetRunnerContext}. The durable-execution context map likewise lives on {@link + * DurableExecutionManager} and is accessed via the manager parameter passed to {@link + * #transferContexts}. + * + *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data + * flows via method parameters. + */ +class ActionTaskContextManager implements AutoCloseable { + + private RunnerContextImpl runnerContext; + + private final Map actionTaskMemoryContexts; + private final Map continuationContexts; + private final Map pythonAwaitableRefs; + + private ContinuationActionExecutor continuationActionExecutor; + + ActionTaskContextManager(int numAsyncThreads) { + this.actionTaskMemoryContexts = new HashMap<>(); + this.continuationContexts = new HashMap<>(); + this.pythonAwaitableRefs = new HashMap<>(); + this.continuationActionExecutor = new ContinuationActionExecutor(numAsyncThreads); + } + + /** + * Returns a runner context for an action's exec language. + * + *

For Java actions, lazily creates a single {@link JavaRunnerContextImpl} that is reused for + * every Java action. For Python actions, returns the supplied {@link PythonRunnerContextImpl} + * (owned by {@link PythonBridgeManager}). Throws {@link IllegalStateException} if a Python + * context is requested but none was provided, or if the continuation executor has not been + * initialized. + * + * @param isJava {@code true} if the action is a Java action, {@code false} if Python. + * @param agentPlan the agent plan, used when creating the Java runner context. + * @param resourceCache the resource cache, used when creating the Java runner context. + * @param metricGroup the agent metric group. + * @param jobIdentifier the job identifier. + * @param mailboxThreadChecker hook used by runner contexts to assert mailbox-thread access. + * @param pythonRunnerContext the pre-built Python runner context, or {@code null} for Java. + * @return the runner context appropriate for the action's exec language. 
+ */ + RunnerContextImpl createOrGetRunnerContext( + boolean isJava, + AgentPlan agentPlan, + ResourceCache resourceCache, + FlinkAgentsMetricGroupImpl metricGroup, + String jobIdentifier, + Runnable mailboxThreadChecker, + PythonRunnerContextImpl pythonRunnerContext) { + if (isJava) { + if (runnerContext == null) { + if (continuationActionExecutor == null) { + throw new IllegalStateException( + "ContinuationActionExecutor has not been initialized."); + } + runnerContext = + new JavaRunnerContextImpl( + metricGroup, + mailboxThreadChecker, + agentPlan, + resourceCache, + jobIdentifier, + continuationActionExecutor); + } + return runnerContext; + } else { + if (pythonRunnerContext == null) { + throw new IllegalStateException( + "PythonRunnerContextImpl has not been initialized."); + } + return pythonRunnerContext; + } + } + + /** + * Resolves the runner context for the given action task, switches it to that task's action, and + * wires its memory, continuation, and Python-awaitable contexts. + * + *

Steps: + * + *

    + *
  1. Selects a Java or Python runner context based on the action's {@code Exec} type. + *
  2. Reuses any existing {@link RunnerContextImpl.MemoryContext} for this task; otherwise + * builds a fresh one backed by the supplied sensory/short-term memory states. + *
  3. Calls {@link RunnerContextImpl#switchActionContext} so the shared context now points at + * this action's name, memory, and key namespace. + *
  4. For Java contexts, attaches a continuation context (re-used if the task is resuming + * from an async suspend, fresh otherwise). + *
  5. For Python contexts, attaches the per-task awaitable reference (or {@code null} if the + * awaitable was lost across a checkpoint restore — the action will then re-execute). + *
+ * + * @param actionTask the task to be set up before execution. + * @param key the current Flink key. + * @param agentPlan the agent plan. + * @param resourceCache the resource cache. + * @param metricGroup the agent metric group. + * @param jobIdentifier the job identifier. + * @param mailboxThreadChecker hook used to assert mailbox-thread access from runner contexts. + * @param sensoryMemState keyed map state backing sensory memory. + * @param shortTermMemState keyed map state backing short-term memory. + * @param pythonRunnerContext the Python runner context, or {@code null} when no Python runtime + * is initialized. + */ + void createAndSetRunnerContext( + ActionTask actionTask, + Object key, + AgentPlan agentPlan, + ResourceCache resourceCache, + FlinkAgentsMetricGroupImpl metricGroup, + String jobIdentifier, + Runnable mailboxThreadChecker, + MapState sensoryMemState, + MapState shortTermMemState, + PythonRunnerContextImpl pythonRunnerContext) { + RunnerContextImpl context; + if (actionTask.action.getExec() instanceof JavaFunction) { + context = + createOrGetRunnerContext( + true, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + mailboxThreadChecker, + pythonRunnerContext); + } else if (actionTask.action.getExec() instanceof PythonFunction) { + context = + createOrGetRunnerContext( + false, + agentPlan, + resourceCache, + metricGroup, + jobIdentifier, + mailboxThreadChecker, + pythonRunnerContext); + } else { + throw new IllegalStateException( + "Unsupported action type: " + actionTask.action.getExec().getClass()); + } + + RunnerContextImpl.MemoryContext memoryContext; + if (actionTaskMemoryContexts.containsKey(actionTask)) { + memoryContext = actionTaskMemoryContexts.get(actionTask); + } else { + memoryContext = + new RunnerContextImpl.MemoryContext( + new CachedMemoryStore(sensoryMemState), + new CachedMemoryStore(shortTermMemState)); + } + + context.switchActionContext( + actionTask.action.getName(), memoryContext, 
String.valueOf(key.hashCode())); + + if (context instanceof JavaRunnerContextImpl) { + ContinuationContext continuationContext; + if (this.hasContinuationContext(actionTask)) { + // action task for async execution action, should retrieve intermediate results + // from map. + continuationContext = this.getContinuationContext(actionTask); + } else { + continuationContext = new ContinuationContext(); + } + ((JavaRunnerContextImpl) context).setContinuationContext(continuationContext); + } + if (context instanceof PythonRunnerContextImpl) { + // Get the awaitable ref from the transient map. After checkpoint restore, this will + // be null, signaling that the awaitable was lost and needs re-execution. + String awaitableRef = this.getPythonAwaitableRef(actionTask); + ((PythonRunnerContextImpl) context).setPythonAwaitableRef(awaitableRef); + } + actionTask.setRunnerContext(context); + } + + @Nullable + RunnerContextImpl.MemoryContext getMemoryContext(ActionTask actionTask) { + return actionTaskMemoryContexts.get(actionTask); + } + + void putMemoryContext(ActionTask actionTask, RunnerContextImpl.MemoryContext memoryContext) { + actionTaskMemoryContexts.put(actionTask, memoryContext); + } + + @Nullable + RunnerContextImpl.MemoryContext removeMemoryContext(ActionTask actionTask) { + return actionTaskMemoryContexts.remove(actionTask); + } + + /** + * Transfers per-task contexts from a finishing action task to the action task it generated. + * + *

Always transfers the memory context. For Java tasks, transfers the continuation context. + * For Python tasks, transfers the awaitable reference when present. The durable-execution + * context map lives on {@link DurableExecutionManager}, so that manager is passed in as a + * parameter rather than held as a field — this keeps the no-manager-to-manager-references + * design constraint intact. + * + * @param fromTask the finishing task whose contexts should be transferred. + * @param toTask the newly generated task that will inherit the contexts. + * @param durableExecManager used to copy the durable-execution context entry, if any. + */ + void transferContexts( + ActionTask fromTask, ActionTask toTask, DurableExecutionManager durableExecManager) { + putMemoryContext(toTask, fromTask.getRunnerContext().getMemoryContext()); + RunnerContextImpl.DurableExecutionContext durableContext = + fromTask.getRunnerContext().getDurableExecutionContext(); + if (durableContext != null) { + durableExecManager.putDurableContext(toTask, durableContext); + } + if (fromTask.getRunnerContext() instanceof JavaRunnerContextImpl) { + this.putContinuationContext( + toTask, + ((JavaRunnerContextImpl) fromTask.getRunnerContext()).getContinuationContext()); + } + if (fromTask.getRunnerContext() instanceof PythonRunnerContextImpl) { + String awaitableRef = + ((PythonRunnerContextImpl) fromTask.getRunnerContext()).getPythonAwaitableRef(); + if (awaitableRef != null) { + this.putPythonAwaitableRef(toTask, awaitableRef); + } + } + } + + @Nullable + ContinuationContext getContinuationContext(ActionTask actionTask) { + return continuationContexts.get(actionTask); + } + + void putContinuationContext(ActionTask actionTask, ContinuationContext context) { + continuationContexts.put(actionTask, context); + } + + void removeContinuationContext(ActionTask actionTask) { + continuationContexts.remove(actionTask); + } + + boolean hasContinuationContext(ActionTask actionTask) { + return 
continuationContexts.containsKey(actionTask); + } + + @Nullable + String getPythonAwaitableRef(ActionTask actionTask) { + return pythonAwaitableRefs.get(actionTask); + } + + void putPythonAwaitableRef(ActionTask actionTask, String ref) { + pythonAwaitableRefs.put(actionTask, ref); + } + + void removePythonAwaitableRef(ActionTask actionTask) { + pythonAwaitableRefs.remove(actionTask); + } + + @Override + public void close() throws Exception { + if (runnerContext != null) { + try { + runnerContext.close(); + } finally { + runnerContext = null; + } + } + if (continuationActionExecutor != null) { + continuationActionExecutor.close(); + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java new file mode 100644 index 000000000..91b8a3351 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/DurableExecutionManager.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.context.MemoryUpdate; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.ActionStateStore; +import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; +import org.apache.flink.agents.runtime.context.ActionStatePersister; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; +import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; + +/** + * Owns the durable-execution side of {@link ActionExecutionOperator}: the optional {@link + * ActionStateStore}, the recovery-marker operator state, the per-checkpoint sequence-number map, + * and the per-{@link ActionTask} {@link RunnerContextImpl.DurableExecutionContext} map. + * + *

Durable mode is optional. If no {@link ActionStateStore} is configured (and none is + * pre-injected via the constructor), all {@code maybe*} methods are no-ops and {@link + * #hasDurableStore()} returns {@code false}. This lets the operator stay agnostic of whether + * durable execution is enabled. + * + *

Lifecycle: instantiated in the operator constructor. {@link + * #maybeInitActionStateStore(AgentConfiguration)} runs from BOTH the operator's {@code + * initializeState()} and {@code open()} — recovery requires the store to be configured before + * {@link #handleRecovery(OperatorStateBackend)} reads from it, and the {@code open()} call ensures + * the store is also available on the normal (non-recovery) path. The method creates a default + * Kafka-backed store when one was not pre-injected, and is idempotent on the second call. {@link + * #handleRecovery(OperatorStateBackend)} runs from the operator's {@code initializeState()} during + * recovery. {@link #initRecoveryMarkerState(OperatorStateBackend)} runs from the operator's {@code + * open()}. {@link #close()} closes the underlying store. + * + *

Design constraint: package-private; no manager-to-manager held references. Cross-cutting data + * flows via method parameters. In particular, {@link + * ActionTaskContextManager#transferContexts(ActionTask, ActionTask, DurableExecutionManager)} + * accepts this manager as a parameter so that the durable-context map can stay encapsulated here. + */ +class DurableExecutionManager implements ActionStatePersister, AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(DurableExecutionManager.class); + + private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; + + private ActionStateStore actionStateStore; + private transient ListState recoveryMarkerOpState; + private final Map> checkpointIdToSeqNums; + + private final Map + actionTaskDurableContexts; + + /** + * @param actionStateStore an optional pre-injected store, primarily for tests. When {@code + * null}, {@link #maybeInitActionStateStore(AgentConfiguration)} may create a default store + * based on configuration; otherwise durable execution is disabled. + */ + DurableExecutionManager(@Nullable ActionStateStore actionStateStore) { + this.actionStateStore = actionStateStore; + this.checkpointIdToSeqNums = new HashMap<>(); + this.actionTaskDurableContexts = new HashMap<>(); + } + + /** + * Lazily creates a default {@link ActionStateStore} from configuration if none was + * pre-injected. + * + *

Only creates a store when this manager was constructed without one and the configuration + * selects a recognized backend (currently Kafka). Otherwise this is a no-op, which leaves + * durable execution disabled. + * + * @param config the agent configuration carrying the backend selection. + */ + void maybeInitActionStateStore(AgentConfiguration config) { + if (actionStateStore == null + && KAFKA.getType().equalsIgnoreCase(config.get(ACTION_STATE_STORE_BACKEND))) { + LOG.info("Using Kafka as backend of action state store."); + actionStateStore = new KafkaActionStateStore(config); + } + } + + boolean hasDurableStore() { + return actionStateStore != null; + } + + void initRecoveryMarkerState(OperatorStateBackend operatorStateBackend) throws Exception { + if (actionStateStore != null) { + recoveryMarkerOpState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + } + } + + /** + * Replays recovery markers from the operator's union-list state to rebuild durable action + * state. + * + *

Called from the operator's {@code initializeState()}, which runs before {@code open()}. + * This means {@link #recoveryMarkerOpState} is not yet initialized, so the union-list state + * descriptor is re-created here using the same descriptor name — Flink returns the same + * underlying state. No-op when durable execution is disabled. + * + * @param operatorStateBackend the operator state backend used to obtain the recovery-marker + * union-list state. + */ + void handleRecovery(OperatorStateBackend operatorStateBackend) throws Exception { + if (actionStateStore != null) { + List markers = new ArrayList<>(); + ListState markerState = + operatorStateBackend.getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, TypeInformation.of(Object.class))); + Iterable recoveryMarkers = markerState.get(); + if (recoveryMarkers != null) { + recoveryMarkers.forEach(markers::add); + } + LOG.info("Rebuilding action state from {} recovery markers", markers.size()); + actionStateStore.rebuildState(markers); + } + } + + @Nullable + ActionState maybeGetActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + return actionStateStore == null + ? null + : actionStateStore.get(key.toString(), sequenceNum, action, event); + } + + void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) + throws Exception { + if (actionStateStore != null) { + if (actionStateStore.get(key, sequenceNum, action, event) == null) { + actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); + } + } + } + + /** + * Persists the result of a finished {@link ActionTask} to the durable store. + * + *

No-op when no store is configured or when the task did not finish (e.g. it suspended on a + * continuation). On finish, accumulates the task's memory updates and emitted output events + * into the {@link ActionState}, marks it completed, persists it, and clears the in-context + * durable bookkeeping. + * + * @param key the key under which the action ran. + * @param sequenceNum the per-key message sequence number. + * @param action the action being persisted. + * @param event the input event that triggered this action. + * @param context the runner context whose memory updates will be folded into the action state. + * @param actionTaskResult the result of running the action task. + */ + void maybePersistTaskResult( + Object key, + long sequenceNum, + Action action, + Event event, + RunnerContextImpl context, + ActionTask.ActionTaskResult actionTaskResult) + throws Exception { + if (actionStateStore == null) { + return; + } + + if (!actionTaskResult.isFinished()) { + return; + } + + ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); + + for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { + actionState.addSensoryMemoryUpdate(memoryUpdate); + } + + for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { + actionState.addShortTermMemoryUpdate(memoryUpdate); + } + + for (Event outputEvent : actionTaskResult.getOutputEvents()) { + actionState.addEvent(outputEvent); + } + + actionState.markCompleted(); + + actionStateStore.put(key, sequenceNum, action, event, actionState); + + context.clearDurableExecutionContext(); + } + + /** + * Wires a {@link RunnerContextImpl.DurableExecutionContext} onto the given action task's runner + * context. + * + *

Returns immediately when no durable store is configured. Otherwise reuses an existing + * {@link RunnerContextImpl.DurableExecutionContext} held in the per-task map (i.e. when + * resuming a continuation), or creates a fresh one bound to this manager so that nested + * persists route back through {@link #persist}. + * + * @param actionTask the action task to attach the context to. + * @param actionState the action state for this (key, sequenceNum, action, event). + * @param seqNum the per-key sequence number. + */ + void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState, long seqNum) { + if (actionStateStore == null) { + return; + } + + RunnerContextImpl.DurableExecutionContext durableContext; + if (actionTaskDurableContexts.containsKey(actionTask)) { + durableContext = actionTaskDurableContexts.get(actionTask); + } else { + durableContext = + new RunnerContextImpl.DurableExecutionContext( + actionTask.getKey(), + seqNum, + actionTask.action, + actionTask.event, + actionState, + this); + } + + actionTask.getRunnerContext().setDurableExecutionContext(durableContext); + } + + @Override + public void persist( + Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { + try { + actionStateStore.put(key, sequenceNumber, action, event, actionState); + } catch (Exception e) { + LOG.error("Failed to persist ActionState", e); + throw new RuntimeException("Failed to persist ActionState", e); + } + } + + void maybePruneState(Object key, long sequenceNum) throws Exception { + if (actionStateStore != null) { + actionStateStore.pruneState(key, sequenceNum); + } + } + + /** + * Prunes durable state for all per-key sequence numbers that were captured at the time of the + * given checkpoint. + * + *

The mapping from checkpoint id to per-key sequence numbers must have been recorded earlier + * via {@link #recordCheckpointSequenceNumbers}. After pruning, the entry for that checkpoint is + * removed. No-op when durable execution is disabled. + * + * @param checkpointId the id of the completed checkpoint. + */ + void notifyCheckpointComplete(long checkpointId) { + if (actionStateStore != null) { + Map keyToSeqNum = + checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); + for (Map.Entry entry : keyToSeqNum.entrySet()) { + actionStateStore.pruneState(entry.getKey(), entry.getValue()); + } + checkpointIdToSeqNums.remove(checkpointId); + } + } + + void snapshotRecoveryMarker() throws Exception { + if (actionStateStore != null) { + Object recoveryMarker = actionStateStore.getRecoveryMarker(); + if (recoveryMarker != null) { + recoveryMarkerOpState.update(List.of(recoveryMarker)); + } + } + } + + /** + * Records the per-key sequence numbers observed when snapshotting the given checkpoint. + * + *

The recorded mapping is consulted later by {@link #notifyCheckpointComplete(long)} to + * prune durable state up to the sequence number that was committed by that checkpoint. + * + * @param checkpointId the checkpoint being snapshotted. + * @param seqNums the per-key sequence numbers captured during the snapshot. + */ + void recordCheckpointSequenceNumbers(long checkpointId, Map seqNums) { + checkpointIdToSeqNums.put(checkpointId, seqNums); + } + + // --- Durable execution context map accessors --- + + @Nullable + RunnerContextImpl.DurableExecutionContext getDurableContext(ActionTask actionTask) { + return actionTaskDurableContexts.get(actionTask); + } + + void putDurableContext( + ActionTask actionTask, RunnerContextImpl.DurableExecutionContext context) { + actionTaskDurableContexts.put(actionTask, context); + } + + void removeDurableContext(ActionTask actionTask) { + actionTaskDurableContexts.remove(actionTask); + } + + boolean hasDurableContext(ActionTask actionTask) { + return actionTaskDurableContexts.containsKey(actionTask); + } + + @VisibleForTesting + @Nullable + ActionStateStore getActionStateStore() { + return actionStateStore; + } + + @Override + public void close() throws Exception { + if (actionStateStore != null) { + actionStateStore.close(); + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java new file mode 100644 index 000000000..be3ea76dd --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.EventContext; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.listener.EventListener; +import org.apache.flink.agents.api.logger.EventLogger; +import org.apache.flink.agents.api.logger.EventLoggerConfig; +import org.apache.flink.agents.api.logger.EventLoggerFactory; +import org.apache.flink.agents.api.logger.EventLoggerOpenParams; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.eventlog.FileEventLogger; +import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; +import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; +import org.apache.flink.agents.runtime.python.event.PythonEvent; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.utils.EventUtil; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +import 
java.util.ArrayList; +import java.util.List; + +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Handles event-side concerns for {@link ActionExecutionOperator}: input/output transformation + * between Java/Python representations, action lookup against the {@link AgentPlan}, event-logger + * and event-listener notification, and watermark draining via the per-key segment queue. + * + *

Owned state: + * + *

    + *
  • The {@link EventLogger} created from the agent plan's logging configuration (may be {@code + * null} when logging is disabled). + *
  • The list of registered {@link EventListener}s. + *
  • A reused {@link StreamRecord} used to emit outputs without per-record allocation. + *
  • The {@link SegmentedQueue} that orders watermarks behind in-flight keys so a watermark is + * only emitted once all keys ahead of it have finished. + *
  • The late-bound {@link BuiltInMetrics} provided in {@link #open(BuiltInMetrics)}. + *
+ * + *

Lifecycle: instantiated in the operator constructor (which decides {@link #inputIsJava}). + * {@link #open(BuiltInMetrics)} runs from the operator's {@code open()} once metrics are available. + * {@link #initEventLogger} also runs from the operator's {@code open()} once the runtime context is + * available (after metrics have been built). {@link #close()} closes the event logger. + * + *

 * <p>Design constraint: package-private; no manager-to-manager held references.
 *
 * @param <IN> input record type
 * @param <OUT> output record type
 */
class EventRouter<IN, OUT> implements AutoCloseable {

    // Decides wrapping in wrapToInputEvent: Java payload vs. two-field Python Row.
    private final boolean inputIsJava;
    // Built from the agent plan's config; the null checks below indicate it may be null
    // when logging is disabled — TODO confirm against EventLoggerFactory.createLogger.
    private final EventLogger eventLogger;
    // Listeners notified from notifyEventProcessed; populated elsewhere.
    private final List<EventListener> eventListeners;
    // Reused to emit outputs without per-record allocation; allocated in open().
    private StreamRecord<OUT> reusedStreamRecord;
    // Orders watermarks behind in-flight keys; allocated in open().
    private SegmentedQueue keySegmentQueue;
    // Late-bound metrics handle supplied by open().
    private BuiltInMetrics builtInMetrics;

    EventRouter(AgentPlan agentPlan, boolean inputIsJava) {
        this.inputIsJava = inputIsJava;
        this.eventLogger = createEventLogger(agentPlan);
        this.eventListeners = new ArrayList<>();
    }

    /**
     * Initializes mutable runtime state that depends on metrics being available.
     *

     * <p>Allocates the reused stream record and the segmented watermark queue, and stores the
     * supplied {@link BuiltInMetrics} for use in {@link #notifyEventProcessed(Event)}. Called from
     * the operator's {@code open()} once metric groups are constructed.
     *
     * @param builtInMetrics the operator's built-in metrics handle.
     */
    void open(BuiltInMetrics builtInMetrics) {
        this.reusedStreamRecord = new StreamRecord<>(null);
        this.keySegmentQueue = new SegmentedQueue();
        this.builtInMetrics = builtInMetrics;
    }

    /**
     * Opens the event logger against the runtime context; no-op when logging is disabled.
     *
     * @param runtimeContext the streaming runtime context handed to the logger's open params.
     */
    void initEventLogger(StreamingRuntimeContext runtimeContext) throws Exception {
        if (eventLogger == null) {
            return;
        }
        eventLogger.open(new EventLoggerOpenParams(runtimeContext));
    }

    /**
     * Wraps an incoming record into an {@link Event} suitable for action dispatch.
     *

Java pipelines wrap the raw input directly into a Java {@link InputEvent}. Python + * pipelines expect a two-field {@link Row} where the first field is the key and the second is + * the actual payload; the payload is converted to a Python event via the supplied {@link + * PythonActionExecutor}. + * + * @param input the raw input record. + * @param pythonActionExecutor the Python action executor (used only when input originates from + * Python). + * @return the wrapped input event. + */ + @SuppressWarnings("unchecked") + Event wrapToInputEvent(IN input, PythonActionExecutor pythonActionExecutor) { + if (inputIsJava) { + return new InputEvent(input); + } else { + // the input data must originate from Python and be of type Row with two fields — the + // first representing the key, and the second representing the actual data payload. + checkState(input instanceof Row && ((Row) input).getArity() == 2); + return pythonActionExecutor.wrapToInputEvent(((Row) input).getField(1)); + } + } + + /** + * Extracts the downstream output payload from an {@link OutputEvent}. + * + *

For a Java {@link OutputEvent}, returns the payload directly. For a Python {@link + * PythonEvent}, delegates to the supplied {@link PythonActionExecutor} to convert the Python + * output back into the Java output type. + * + * @param event the output event (must satisfy {@link EventUtil#isOutputEvent(Event)}). + * @param pythonActionExecutor the Python action executor (used only for Python events). + * @return the typed output payload. + * @throws IllegalStateException if the event is not a recognized output-event type. + */ + @SuppressWarnings("unchecked") + OUT getOutputFromOutputEvent(Event event, PythonActionExecutor pythonActionExecutor) { + checkState(EventUtil.isOutputEvent(event)); + if (event instanceof OutputEvent) { + return (OUT) ((OutputEvent) event).getOutput(); + } else if (event instanceof PythonEvent) { + Object outputFromOutputEvent = + pythonActionExecutor.getOutputFromOutputEvent(((PythonEvent) event).getEvent()); + return (OUT) outputFromOutputEvent; + } else { + throw new IllegalStateException( + "Unsupported event type: " + event.getClass().getName()); + } + } + + List getActionsTriggeredBy(Event event, AgentPlan agentPlan) { + if (event instanceof PythonEvent) { + return agentPlan.getActionsTriggeredBy(((PythonEvent) event).getEventType()); + } else { + return agentPlan.getActionsTriggeredBy(event.getClass().getName()); + } + } + + /** + * Notifies the configured event sinks (logger, listeners, metrics) that an event was processed. + * + *

If event logging is enabled, appends and immediately flushes the event. Then notifies + * every registered {@link EventListener}. Finally increments the {@code eventProcessed} + * built-in metric. The event logger is flushed per call as a temporary measure pending a + * batched flush mechanism. + * + * @param event the event that was just processed. + */ + void notifyEventProcessed(Event event) throws Exception { + EventContext eventContext = new EventContext(event); + if (eventLogger != null) { + // If event logging is enabled, we log the event along with its context. + eventLogger.append(eventContext, event); + // For now, we flush the event logger after each event to ensure immediate logging. + // This is a temporary solution to ensure that events are logged immediately. + // TODO: In the future, we may want to implement a more efficient batching mechanism. + eventLogger.flush(); + } + if (eventListeners != null) { + // Notify all registered event listeners about the event. + for (EventListener listener : eventListeners) { + listener.onEventProcessed(eventContext, event); + } + } + builtInMetrics.markEventProcessed(); + } + + /** + * Drains all watermarks from the segmented queue that are now eligible to be emitted. + * + *

A watermark becomes eligible once every key in the segment ahead of it has finished + * processing. This method pops watermarks in order and forwards each to the supplied {@link + * WatermarkEmitter}. + * + * @param watermarkEmitter callback that emits a watermark downstream. + */ + void processEligibleWatermarks(WatermarkEmitter watermarkEmitter) throws Exception { + Watermark mark = keySegmentQueue.popOldestWatermark(); + while (mark != null) { + watermarkEmitter.emit(mark); + mark = keySegmentQueue.popOldestWatermark(); + } + } + + SegmentedQueue getKeySegmentQueue() { + return keySegmentQueue; + } + + StreamRecord getReusedStreamRecord() { + return reusedStreamRecord; + } + + @VisibleForTesting + @Nullable + EventLogger getEventLogger() { + return eventLogger; + } + + private EventLogger createEventLogger(AgentPlan agentPlan) { + EventLoggerConfig.Builder loggerConfigBuilder = EventLoggerConfig.builder(); + String baseLogDir = agentPlan.getConfig().get(BASE_LOG_DIR); + if (baseLogDir != null && !baseLogDir.trim().isEmpty()) { + loggerConfigBuilder.property(FileEventLogger.BASE_LOG_DIR_PROPERTY_KEY, baseLogDir); + } + loggerConfigBuilder.property( + FileEventLogger.PRETTY_PRINT_PROPERTY_KEY, agentPlan.getConfig().get(PRETTY_PRINT)); + return EventLoggerFactory.createLogger(loggerConfigBuilder.build()); + } + + @Override + public void close() throws Exception { + if (eventLogger != null) { + eventLogger.close(); + } + } + + @FunctionalInterface + interface WatermarkEmitter { + void emit(Watermark mark) throws Exception; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java new file mode 100644 index 000000000..785d43beb --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/OperatorStateManager.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; + +import javax.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.agents.runtime.utils.StateUtil.*; + +/** + * Owns all Flink state used by {@link ActionExecutionOperator} and exposes a narrow API for + * 
accessing it. + * + *

Owned state: + * + *

    + *
  • Keyed list state of pending {@link ActionTask}s for the current key. + *
  • Keyed list state of pending {@link Event}s buffered while another input is processing. + *
  • Keyed value state holding the per-key message sequence number. + *
  • Keyed map states for sensory and short-term memory. + *
  • Operator union-list state of keys currently being processed (used after rescale to recover + * in-flight work). + *
+ * + *

Lifecycle: instantiated by the operator's {@code initializeState()} (the Flink lifecycle runs + * {@code initializeState} before {@code open}). Both {@link + * #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext)} and {@link + * #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's {@code + * open()}. There is no explicit close — the underlying state handles are owned by Flink. + * + *

 * <p>Design constraint: package-private; no manager-to-manager held references. Cross-cutting data
 * flows via method parameters (see for example {@link ActionTaskContextManager#transferContexts}
 * which takes a {@link DurableExecutionManager} as an argument rather than holding one).
 */
class OperatorStateManager {

    static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber";
    static final String PENDING_INPUT_EVENT_STATE_NAME = "pendingInputEvents";

    // Keyed list of ActionTasks still to run for the current key.
    private ListState<ActionTask> actionTasksKState;
    // Keyed list of input events buffered while an earlier input for the key is in flight.
    private ListState<Event> pendingInputEventsKState;
    // Operator union-list state of keys with in-flight work (visible to all subtasks on rescale).
    private ListState<Object> currentProcessingKeysOpState;
    // Per-key message sequence number; null until initOrIncSequenceNumber runs for the key.
    private ValueState<Long> sequenceNumberKState;
    // Keyed sensory-memory map (name -> memory item).
    private MapState<String, MemoryObjectImpl.MemoryItem> sensoryMemState;
    // Keyed short-term-memory map (name -> memory item).
    private MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState;

    OperatorStateManager() {}

    /**
     * Registers all keyed-state descriptors against the operator's runtime context.
     *

     * <p>Registers: sensory memory map state, short-term memory map state, the per-key message
     * sequence-number value state, the per-key list of pending {@link ActionTask}s, and the per-key
     * list of buffered {@link Event}s. Called from the operator's {@code open()} method.
     *
     * @param runtimeContext the operator's runtime context, used to obtain keyed state handles.
     */
    void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext runtimeContext)
            throws Exception {
        // init sensoryMemState
        MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> sensoryMemStateDescriptor =
                new MapStateDescriptor<>(
                        "sensoryMemory",
                        TypeInformation.of(String.class),
                        TypeInformation.of(MemoryObjectImpl.MemoryItem.class));
        sensoryMemState = runtimeContext.getMapState(sensoryMemStateDescriptor);

        // init shortTermMemState
        MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> shortTermMemStateDescriptor =
                new MapStateDescriptor<>(
                        "shortTermMemory",
                        TypeInformation.of(String.class),
                        TypeInformation.of(MemoryObjectImpl.MemoryItem.class));
        shortTermMemState = runtimeContext.getMapState(shortTermMemStateDescriptor);

        // init sequence number state for per key message ordering
        sequenceNumberKState =
                runtimeContext.getState(
                        new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class));

        // init agent processing related state
        actionTasksKState =
                runtimeContext.getListState(
                        new ListStateDescriptor<>(
                                "actionTasks", TypeInformation.of(ActionTask.class)));
        pendingInputEventsKState =
                runtimeContext.getListState(
                        new ListStateDescriptor<>(
                                PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)));
    }

    /**
     * Registers operator-level (non-keyed) state.
     *

     * <p>Registers the {@code currentProcessingKeys} union-list state. A union-list lets every
     * subtask see all keys after a rescale; the operator filters out keys that do not belong to the
     * current subtask's key-group range during recovery. Called from the operator's {@code open()}
     * method (after {@code super.open()} and after the keyed-state setup).
     *
     * @param operatorStateBackend the operator state backend used to obtain operator state.
     */
    void initializeOperatorStates(OperatorStateBackend operatorStateBackend) throws Exception {
        // We use UnionList here to ensure that the task can access all keys after parallelism
        // modifications.
        // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do
        // not belong to the key range of current task.
        currentProcessingKeysOpState =
                operatorStateBackend.getUnionListState(
                        new ListStateDescriptor<>(
                                "currentProcessingKeys", TypeInformation.of(Object.class)));
    }

    /**
     * Advances the per-key message sequence number.
     *

If the state has no value for the current key, sets it to {@code 0L}. Otherwise increments + * the existing value by one. Must be called under a keyed context. + */ + void initOrIncSequenceNumber() throws Exception { + // Initialize the sequence number state if it does not exist. + Long sequenceNumber = sequenceNumberKState.value(); + if (sequenceNumber == null) { + sequenceNumberKState.update(0L); + } else { + sequenceNumberKState.update(sequenceNumber + 1); + } + } + + long getSequenceNumber() throws Exception { + return sequenceNumberKState.value(); + } + + boolean hasMoreActionTasks() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + + /** + * Removes and returns the next pending {@link ActionTask} for the current key. + * + * @return the next {@link ActionTask}, or {@code null} if the queue for the current key is + * empty. + */ + @Nullable + ActionTask pollNextActionTask() throws Exception { + return pollFromListState(actionTasksKState); + } + + void addActionTask(ActionTask actionTask) throws Exception { + actionTasksKState.add(actionTask); + } + + void addPendingInputEvent(Event event) throws Exception { + pendingInputEventsKState.add(event); + } + + /** + * Removes and returns the next pending input {@link Event} buffered for the current key. + * + * @return the next buffered input {@link Event}, or {@code null} if the buffer for the current + * key is empty. 
     */
    @Nullable
    Event pollNextPendingInputEvent() throws Exception {
        return pollFromListState(pendingInputEventsKState);
    }

    /** Records the key in the operator-level union-list of keys with in-flight work. */
    void addProcessingKey(Object key) throws Exception {
        currentProcessingKeysOpState.add(key);
    }

    // Removes the key from the union-list state. Return value comes from
    // StateUtil.removeFromListState — presumably the number of removed occurrences;
    // TODO confirm against StateUtil.
    int removeProcessingKey(Object key) throws Exception {
        return removeFromListState(currentProcessingKeysOpState, key);
    }

    /** Returns whether any key is currently marked as processing. */
    boolean hasProcessingKeys() throws Exception {
        return listStateNotEmpty(currentProcessingKeysOpState);
    }

    /** Returns all keys currently marked as processing (union view across subtasks). */
    Iterable<Object> getProcessingKeys() throws Exception {
        return currentProcessingKeysOpState.get();
    }

    MapState<String, MemoryObjectImpl.MemoryItem> getSensoryMemState() {
        return sensoryMemState;
    }

    MapState<String, MemoryObjectImpl.MemoryItem> getShortTermMemState() {
        return shortTermMemState;
    }

    /** Computes the key-group range owned by this subtask from its parallelism and index. */
    KeyGroupRange getCurrentSubtaskKeyGroupRange(
            int maxParallelism,
            org.apache.flink.api.common.functions.RuntimeContext runtimeContext) {
        int parallelism = runtimeContext.getTaskInfo().getNumberOfParallelSubtasks();
        int subtaskIndex = runtimeContext.getTaskInfo().getIndexOfThisSubtask();
        return KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                maxParallelism, parallelism, subtaskIndex);
    }

    /** Returns whether the key's key-group falls inside this subtask's key-group range. */
    boolean isKeyOwnedByCurrentSubtask(
            Object key, int maxParallelism, KeyGroupRange currentSubtaskKeyGroupRange) {
        int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
        return currentSubtaskKeyGroupRange.contains(keyGroup);
    }

    /**
     * Captures the current per-key sequence numbers across all keys held by the given backend.
     *
     * <p>Invoked during checkpoint snapshotting so the caller can later associate the snapshot's
     * per-key sequence numbers with a checkpoint id (see {@link
     * DurableExecutionManager#recordCheckpointSequenceNumbers}).
     *
     * @param keyedStateBackend the keyed state backend to scan.
     * @return a freshly allocated map from key to its current sequence number. (NOTE(review): the
     *     previous doc claimed this map is immutable, but a plain mutable {@link HashMap} is
     *     returned.)
     */
    @SuppressWarnings("unchecked")
    Map<Object, Long> snapshotSequenceNumbers(KeyedStateBackend<?> keyedStateBackend)
            throws Exception {
        HashMap<Object, Long> keyToSeqNum = new HashMap<>();
        ((KeyedStateBackend<Object>) keyedStateBackend)
                .applyToAllKeys(
                        VoidNamespace.INSTANCE,
                        VoidNamespaceSerializer.INSTANCE,
                        new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class),
                        (key, state) -> keyToSeqNum.put(key, state.value()));
        return keyToSeqNum;
    }

    /**
     * Applies a function to the pending-input-event list state for every key in the backend.
     *

Used during recovery to scan all keys that hold buffered input events so the operator can + * resume processing them after a restore. + * + * @param keyedStateBackend the keyed state backend to scan. + * @param function the function to apply per (key, list-state) pair. + */ + @SuppressWarnings("unchecked") + void forEachPendingInputEventKey( + KeyedStateBackend keyedStateBackend, + KeyedStateFunction> function) + throws Exception { + ((KeyedStateBackend) keyedStateBackend) + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ListStateDescriptor<>( + PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)), + function); + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java new file mode 100644 index 000000000..ba765924f --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; +import org.apache.flink.agents.runtime.PythonMCPResourceDiscovery; +import org.apache.flink.agents.runtime.ResourceCache; +import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; +import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; +import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; +import org.apache.flink.agents.runtime.python.utils.JavaResourceAdapter; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; +import org.apache.flink.agents.runtime.python.utils.PythonResourceAdapterImpl; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.apache.flink.python.env.PythonDependencyInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pemja.core.PythonInterpreter; + +import javax.annotation.Nullable; + +import java.util.HashMap; + +/** + * Owns the embedded Python runtime used by {@link ActionExecutionOperator} when an agent plan + * contains Python actions or Python-defined resources. + * + *

Owned state: + * + *

    + *
  • The {@link PythonEnvironmentManager} that prepares dependencies and the Pemja runtime. + *
  • The {@link PythonInterpreter} obtained from that environment. + *
  • The {@link PythonActionExecutor} (when the plan contains Python actions). + *
  • The {@link PythonRunnerContextImpl} consumed by Python actions. + *
  • The Java/Python resource adapters that bridge resource lookups across languages. + *
+ * + *

Lifecycle: instantiated by the operator's {@code open()} (lazy — not in the operator + * constructor), then immediately initialized via {@link #open} in the same call. {@link #open} is a + * no-op when the agent plan contains no Python actions and no Python resources — in that case all + * accessors return {@code null} and {@link #isInitialized()} returns {@code false}. {@link + * #close()} closes the owned resources in the reverse order of creation: {@code + * pythonActionExecutor} → {@code pythonInterpreter} → {@code pythonEnvironmentManager}. + * + *

Design constraint: package-private; no manager-to-manager held references. Other managers + * receive what they need (e.g. the Python runner context, the action executor) via method + * parameters. + */ +class PythonBridgeManager implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(PythonBridgeManager.class); + + private PythonEnvironmentManager pythonEnvironmentManager; + private PythonInterpreter pythonInterpreter; + private PythonActionExecutor pythonActionExecutor; + private PythonRunnerContextImpl pythonRunnerContext; + private PythonResourceAdapterImpl pythonResourceAdapter; + private JavaResourceAdapter javaResourceAdapter; + private boolean initialized; + + PythonBridgeManager() { + this.initialized = false; + } + + /** + * Initializes the Python runtime if the agent plan needs it. + * + *
+ * <p>Scans the agent plan for any {@link PythonFunction} action or {@link
+ * PythonResourceProvider}. If neither is present, this method is a no-op and {@link
+ * #isInitialized()} stays {@code false}. Otherwise it builds the {@link
+ * PythonEnvironmentManager}, opens an embedded {@link PythonInterpreter}, constructs the shared
+ * {@link PythonRunnerContextImpl}, wires the Java/Python resource adapters, and conditionally
+ * initializes the Python action executor and the Python resource adapter (each only when the
+ * corresponding component is present in the plan).
+ *
+ * @param agentPlan the agent plan describing actions and resources.
+ * @param resourceCache the resource cache visible to both languages.
+ * @param executionConfig used to derive Python dependency information.
+ * @param distributedCache used to resolve distributed Python files.
+ * @param tmpDirs Flink-managed temp directories made available to Python.
+ * @param jobId the Flink job id.
+ * @param metricGroup the agent metric group, exposed to Python via the runner context.
+ * @param mailboxThreadChecker hook used by the runner context to assert mailbox-thread access.
+ * @param jobIdentifier the job identifier used to scope Python state.
+ */ + void open( + AgentPlan agentPlan, + ResourceCache resourceCache, + ExecutionConfig executionConfig, + org.apache.flink.api.common.cache.DistributedCache distributedCache, + String[] tmpDirs, + JobID jobId, + FlinkAgentsMetricGroupImpl metricGroup, + Runnable mailboxThreadChecker, + String jobIdentifier) + throws Exception { + boolean containPythonAction = + agentPlan.getActions().values().stream() + .anyMatch(action -> action.getExec() instanceof PythonFunction); + + boolean containPythonResource = + agentPlan.getResourceProviders().values().stream() + .anyMatch( + resourceProviderMap -> + resourceProviderMap.values().stream() + .anyMatch( + resourceProvider -> + resourceProvider + instanceof + PythonResourceProvider)); + + if (containPythonAction || containPythonResource) { + LOG.debug("Begin initialize PythonEnvironmentManager."); + PythonDependencyInfo dependencyInfo = + PythonDependencyInfo.create( + executionConfig.toConfiguration(), distributedCache); + pythonEnvironmentManager = + new PythonEnvironmentManager( + dependencyInfo, tmpDirs, new HashMap<>(System.getenv()), jobId); + pythonEnvironmentManager.open(); + EmbeddedPythonEnvironment env = pythonEnvironmentManager.createEnvironment(); + pythonInterpreter = env.getInterpreter(); + pythonRunnerContext = + new PythonRunnerContextImpl( + metricGroup, + mailboxThreadChecker, + agentPlan, + resourceCache, + jobIdentifier); + + javaResourceAdapter = + new JavaResourceAdapter( + (name, type) -> { + try { + return resourceCache.getResource(name, type); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + pythonInterpreter); + if (containPythonResource) { + initPythonResourceAdapter(agentPlan, resourceCache); + } + if (containPythonAction) { + initPythonActionExecutor(agentPlan, jobIdentifier); + } + initialized = true; + } + } + + private void initPythonActionExecutor(AgentPlan agentPlan, String jobIdentifier) + throws Exception { + pythonActionExecutor = + new PythonActionExecutor( + 
pythonInterpreter, + agentPlan, + javaResourceAdapter, + pythonRunnerContext, + jobIdentifier); + pythonActionExecutor.open(); + } + + private void initPythonResourceAdapter(AgentPlan agentPlan, ResourceCache resourceCache) + throws Exception { + pythonResourceAdapter = + new PythonResourceAdapterImpl( + (String anotherName, ResourceType anotherType) -> { + try { + return resourceCache.getResource(anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + pythonInterpreter, + javaResourceAdapter); + pythonResourceAdapter.open(); + PythonMCPResourceDiscovery.discoverPythonMCPResources( + agentPlan.getResourceProviders(), pythonResourceAdapter, resourceCache); + } + + /** + * @return the Python action executor, or {@code null} if the agent plan contains no Python + * actions (or {@link #open} has not yet been called). + */ + @Nullable + PythonActionExecutor getPythonActionExecutor() { + return pythonActionExecutor; + } + + /** + * @return the Python runner context, or {@code null} if no Python runtime was initialized + * because the agent plan has neither Python actions nor Python resources. 
+ */ + @Nullable + PythonRunnerContextImpl getPythonRunnerContext() { + return pythonRunnerContext; + } + + boolean isInitialized() { + return initialized; + } + + @Override + public void close() throws Exception { + if (pythonActionExecutor != null) { + pythonActionExecutor.close(); + } + if (pythonInterpreter != null) { + pythonInterpreter.close(); + } + if (pythonEnvironmentManager != null) { + pythonEnvironmentManager.close(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index a45a52d4a..f5608b099 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -252,12 +252,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore).isNotNull(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); @@ -313,9 +310,7 @@ void testEventLogBaseDirFromAgentConfig() throws Exception { testHarness.open(); ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - Field eventLoggerField = ActionExecutionOperator.class.getDeclaredField("eventLogger"); - eventLoggerField.setAccessible(true); - Object eventLogger = eventLoggerField.get(operator); + Object eventLogger = 
operator.getEventRouter().getEventLogger(); assertThat(eventLogger).isInstanceOf(FileEventLogger.class); Field configField = FileEventLogger.class.getDeclaredField("config"); @@ -346,12 +341,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Use reflection to access the action state store for validation - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 3L; testHarness.processElement(new StreamRecord<>(inputValue)); @@ -421,12 +413,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Access the action state store - java.lang.reflect.Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); // Process multiple elements with same key to test state persistence testHarness.processElement(new StreamRecord<>(1L)); @@ -490,12 +479,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(true)), (List>) testHarness.getRecordOutput(); assertThat(recordOutput.size()).isEqualTo(3); - // Access the action state store - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); InMemoryActionStateStore actionStateStore = - (InMemoryActionStateStore) 
actionStateStoreField.get(operator); + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); } } @@ -514,11 +500,9 @@ agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), ActionExecutionOperator operator = (ActionExecutionOperator) testHarness.getOperator(); - // Access the action state store - Field actionStateStoreField = - ActionExecutionOperator.class.getDeclaredField("actionStateStore"); - actionStateStoreField.setAccessible(true); - actionStateStore = (InMemoryActionStateStore) actionStateStoreField.get(operator); + actionStateStore = + (InMemoryActionStateStore) + operator.getDurableExecutionManager().getActionStateStore(); Long inputValue = 7L; diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java new file mode 100644 index 000000000..69b7ecb56 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionTaskContextManagerTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.async.ContinuationContext; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Contract tests for {@link ActionTaskContextManager}. */ +class ActionTaskContextManagerTest { + + @Test + void perTaskMapsAreIsolatedAcrossPutGetRemove() throws Exception { + try (ActionTaskContextManager mgr = new ActionTaskContextManager(1)) { + Action action = TestActions.noopAction(); + ActionTask t1 = new JavaActionTask("k", new InputEvent(1L), action); + ActionTask t2 = new JavaActionTask("k", new InputEvent(2L), action); + + ContinuationContext c1 = new ContinuationContext(); + mgr.putContinuationContext(t1, c1); + mgr.putPythonAwaitableRef(t2, "ref-2"); + + // Cross-task isolation: each map only carries the entry it was given. 
+ assertThat(mgr.getContinuationContext(t1)).isSameAs(c1); + assertThat(mgr.getContinuationContext(t2)).isNull(); + assertThat(mgr.getPythonAwaitableRef(t1)).isNull(); + assertThat(mgr.getPythonAwaitableRef(t2)).isEqualTo("ref-2"); + assertThat(mgr.hasContinuationContext(t1)).isTrue(); + assertThat(mgr.hasContinuationContext(t2)).isFalse(); + + // Remove and re-check + mgr.removeContinuationContext(t1); + mgr.removePythonAwaitableRef(t2); + assertThat(mgr.hasContinuationContext(t1)).isFalse(); + assertThat(mgr.getPythonAwaitableRef(t2)).isNull(); + } + } + + @Test + void createOrGetRunnerContextThrowsWhenPythonContextRequestedButNull() throws Exception { + try (ActionTaskContextManager mgr = new ActionTaskContextManager(1)) { + assertThatThrownBy( + () -> + mgr.createOrGetRunnerContext( + /* isJava */ false, + /* agentPlan */ null, + /* resourceCache */ null, + /* metricGroup */ null, + /* jobIdentifier */ "job", + /* mailboxThreadChecker */ () -> {}, + /* pythonRunnerContext */ null)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("PythonRunnerContextImpl has not been initialized"); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java new file mode 100644 index 000000000..98a8b4091 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/DurableExecutionManagerTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.InMemoryActionStateStore; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** Contract tests for {@link DurableExecutionManager}. */ +class DurableExecutionManagerTest { + + @Test + void noStoreModeMakesAllMaybeOperationsNoOp() throws Exception { + DurableExecutionManager dem = new DurableExecutionManager(null); + // No ACTION_STATE_STORE_BACKEND set → no default store should be created. + dem.maybeInitActionStateStore(new AgentConfiguration()); + + assertThat(dem.hasDurableStore()).isFalse(); + assertThat(dem.getActionStateStore()).isNull(); + + Action action = TestActions.noopAction(); + Event event = new InputEvent(0L); + + // Every maybe* method must be a silent no-op. 
+ assertThat(dem.maybeGetActionState("k", 0L, action, event)).isNull(); + dem.maybeInitActionState("k", 0L, action, event); + dem.maybePruneState("k", 0L); + dem.notifyCheckpointComplete(1L); + dem.snapshotRecoveryMarker(); + dem.close(); + } + + @Test + void withInjectedStorePersistsTaskResult() throws Exception { + InMemoryActionStateStore store = new InMemoryActionStateStore(false); + DurableExecutionManager dem = new DurableExecutionManager(store); + + assertThat(dem.hasDurableStore()).isTrue(); + assertThat(dem.getActionStateStore()).isSameAs(store); + + Action action = TestActions.noopAction(); + Event event = new InputEvent(42L); + String key = "key-1"; + long seq = 0L; + + // First call seeds an initial ActionState in the store. + dem.maybeInitActionState(key, seq, action, event); + assertThat(store.getKeyedActionStates()).containsKey(key); + assertThat(dem.maybeGetActionState(key, seq, action, event)).isNotNull(); + + // Build a finished task result with one output event; verify persist folds it into state. 
+ Event outEvent = new OutputEvent(99L); + RunnerContextImpl context = mock(RunnerContextImpl.class); + when(context.getSensoryMemoryUpdates()).thenReturn(List.of()); + when(context.getShortTermMemoryUpdates()).thenReturn(List.of()); + + ActionTask.ActionTaskResult finishedResult = mock(ActionTask.ActionTaskResult.class); + when(finishedResult.isFinished()).thenReturn(true); + when(finishedResult.getOutputEvents()).thenReturn(List.of(outEvent)); + + dem.maybePersistTaskResult(key, seq, action, event, context, finishedResult); + + ActionState persisted = dem.maybeGetActionState(key, seq, action, event); + assertThat(persisted).isNotNull(); + assertThat(persisted.getOutputEvents()).contains(outEvent); + assertThat(persisted.isCompleted()).isTrue(); + verify(context).clearDurableExecutionContext(); + + dem.close(); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java new file mode 100644 index 000000000..a93ab70f1 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Contract tests for {@link EventRouter}. */ +class EventRouterTest { + + @Test + void wrapToInputEventReturnsJavaInputEventForJavaInput() { + AgentPlan plan = new AgentPlan(new HashMap<>(), new HashMap<>()); + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + + Event event = router.wrapToInputEvent(42L, /* pythonActionExecutor */ null); + + assertThat(event).isInstanceOf(InputEvent.class); + assertThat(((InputEvent) event).getInput()).isEqualTo(42L); + } + + @Test + void getActionsTriggeredByReturnsActionsForJavaEventClass() throws Exception { + Action action = TestActions.noopAction(); + Map actions = Map.of(action.getName(), action); + Map> byEvent = Map.of(InputEvent.class.getName(), List.of(action)); + AgentPlan plan = new AgentPlan(actions, byEvent); + + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + + List triggered = router.getActionsTriggeredBy(new InputEvent(0L), plan); + + assertThat(triggered).containsExactly(action); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java new file mode 100644 index 000000000..f7226826e --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Contract tests for {@link PythonBridgeManager}. */ +class PythonBridgeManagerTest { + + @Test + void openIsNoOpWhenPlanHasNeitherPythonActionsNorResources() throws Exception { + // Java-only plan: one Java action, no resources. 
+ Action javaAction = TestActions.noopAction(); + Map actions = Map.of(javaAction.getName(), javaAction); + Map> byEvent = Map.of(InputEvent.class.getName(), List.of(javaAction)); + AgentPlan plan = new AgentPlan(actions, byEvent); + + try (PythonBridgeManager bridge = new PythonBridgeManager()) { + bridge.open( + plan, + /* resourceCache */ null, + new ExecutionConfig(), + /* distributedCache */ null, + /* tmpDirs */ new String[] {System.getProperty("java.io.tmpdir")}, + /* jobId */ new JobID(), + /* metricGroup */ null, + /* mailboxThreadChecker */ () -> {}, + /* jobIdentifier */ "job-1"); + + // No-op contract: nothing initialized, no Pemja interpreter created. + assertThat(bridge.isInitialized()).isFalse(); + assertThat(bridge.getPythonActionExecutor()).isNull(); + assertThat(bridge.getPythonRunnerContext()).isNull(); + } + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java new file mode 100644 index 000000000..f427814a0 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/TestActions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.actions.Action; + +import java.util.List; + +/** + * Shared helpers for the manager contract tests in this package. + * + *
+ * <p>
Provides a minimal {@link Action} backed by a no-op static Java function so individual tests + * do not need to redeclare the boilerplate around {@link JavaFunction#JavaFunction(Class, String, + * Class[])} signature checks. + */ +final class TestActions { + + private TestActions() {} + + /** Returns a minimal noop Java action backed by {@link #noop(InputEvent, RunnerContext)}. */ + static Action noopAction() { + try { + return new Action( + "noop", + new JavaFunction( + TestActions.class, + "noop", + new Class[] {InputEvent.class, RunnerContext.class}), + List.of(InputEvent.class.getName())); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** No-op static method referenced by {@link #noopAction()}. Must be public for reflection. */ + public static void noop(InputEvent event, RunnerContext context) {} +}