diff --git a/invokeai/frontend/web/src/services/events/invocationTracking.test.ts b/invokeai/frontend/web/src/services/events/invocationTracking.test.ts index 68ab114af33..d9d8a5ae627 100644 --- a/invokeai/frontend/web/src/services/events/invocationTracking.test.ts +++ b/invokeai/frontend/web/src/services/events/invocationTracking.test.ts @@ -4,7 +4,6 @@ import { clearCompletedInvocationKeysForQueueItem, hasCompletedInvocationKey, markInvocationAsCompleted, - shouldIgnoreFinishedQueueItemInvocationEvent, } from './invocationTracking'; describe(markInvocationAsCompleted.name, () => { @@ -30,18 +29,3 @@ describe(markInvocationAsCompleted.name, () => { expect(hasCompletedInvocationKey(completedInvocationKeysByItemId, 2, 'prepared-node-1')).toBe(true); }); }); - -describe(shouldIgnoreFinishedQueueItemInvocationEvent.name, () => { - it('ignores late started and progress events for finished queue items', () => { - const finishedQueueItemIds = new Set([1]); - - expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_started', finishedQueueItemIds, 1)).toBe(true); - expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_progress', finishedQueueItemIds, 1)).toBe(true); - }); - - it('does not ignore late error events for finished queue items', () => { - const finishedQueueItemIds = new Set([1]); - - expect(shouldIgnoreFinishedQueueItemInvocationEvent('invocation_error', finishedQueueItemIds, 1)).toBe(false); - }); -}); diff --git a/invokeai/frontend/web/src/services/events/invocationTracking.ts b/invokeai/frontend/web/src/services/events/invocationTracking.ts index f5955721b81..387126a4174 100644 --- a/invokeai/frontend/web/src/services/events/invocationTracking.ts +++ b/invokeai/frontend/web/src/services/events/invocationTracking.ts @@ -1,11 +1,5 @@ type CompletedInvocationKeysByItemId = Map>; -type FinishedQueueItemIds = { - has: (itemId: number) => boolean; -}; - -type FinishedQueueItemInvocationEventName = 'invocation_error' | 'invocation_progress' | 'invocation_started'; - export const hasCompletedInvocationKey = ( completedInvocationKeysByItemId: CompletedInvocationKeysByItemId, itemId: number, @@ -31,15 +25,3 @@ export const clearCompletedInvocationKeysForQueueItem = ( ) => { completedInvocationKeysByItemId.delete(itemId); }; - -export const shouldIgnoreFinishedQueueItemInvocationEvent = ( - eventName: FinishedQueueItemInvocationEventName, - finishedQueueItemIds: FinishedQueueItemIds, - itemId: number -) => { - if (eventName === 'invocation_error') { - return false; - } - - return finishedQueueItemIds.has(itemId); -}; diff --git a/invokeai/frontend/web/src/services/events/nodeExecutionState.test.ts b/invokeai/frontend/web/src/services/events/nodeExecutionState.test.ts index e868f25da3d..8ce96bbc8a0 100644 --- a/invokeai/frontend/web/src/services/events/nodeExecutionState.test.ts +++ b/invokeai/frontend/web/src/services/events/nodeExecutionState.test.ts @@ -4,6 +4,8 @@ import type { S } from 'services/api/types'; import { describe, expect, it } from 'vitest'; import { + getCompletedInvocationIdsFromCompletedSession, + getNodeExecutionStatesFromCompletedSession, getUpdatedNodeExecutionStateOnInvocationComplete, getUpdatedNodeExecutionStateOnInvocationError, getUpdatedNodeExecutionStateOnInvocationProgress, @@ -278,3 +280,56 @@ describe(getUpdatedNodeExecutionStateOnInvocationError.name, () => { }); }); }); + +describe(getNodeExecutionStatesFromCompletedSession.name, () => { + it('builds completed node execution states from a completed session', () => { + const result = { type: 'integer_output', value: 42 } as unknown as S['GraphExecutionState']['results'][string]; + const states = getNodeExecutionStatesFromCompletedSession({ + source_prepared_mapping: { + 'node-1': ['prepared-node-1'], + }, + results: { + 'prepared-node-1': result, + }, + } as unknown as S['GraphExecutionState']); + + expect(states).toEqual([ + { + nodeId: 'node-1', + status: zNodeStatus.enum.COMPLETED, + progress: null, + progressImage: null, + outputs: [result], + error: null, + }, + ]); + }); + + it('does not create a completed state for source nodes with no results', () => { + const states = getNodeExecutionStatesFromCompletedSession({ + source_prepared_mapping: { + 'node-1': ['prepared-node-1'], + }, + results: {}, + } as unknown as S['GraphExecutionState']); + + expect(states).toEqual([]); + }); +}); + +describe(getCompletedInvocationIdsFromCompletedSession.name, () => { + it('returns prepared invocation ids that have persisted results', () => { + const result = { type: 'integer_output', value: 42 } as unknown as S['GraphExecutionState']['results'][string]; + const invocationIds = getCompletedInvocationIdsFromCompletedSession({ + source_prepared_mapping: { + 'node-1': ['prepared-node-1'], + 'node-2': ['prepared-node-2'], + }, + results: { + 'prepared-node-1': result, + }, + } as unknown as S['GraphExecutionState']); + + expect(invocationIds).toEqual(['prepared-node-1']); + }); +}); diff --git a/invokeai/frontend/web/src/services/events/nodeExecutionState.ts b/invokeai/frontend/web/src/services/events/nodeExecutionState.ts index c15f0467772..dbd1bc0492f 100644 --- a/invokeai/frontend/web/src/services/events/nodeExecutionState.ts +++ b/invokeai/frontend/web/src/services/events/nodeExecutionState.ts @@ -83,3 +83,42 @@ export const getUpdatedNodeExecutionStateOnInvocationError = ( return _nodeExecutionState; }; + +export const getNodeExecutionStatesFromCompletedSession = ( + session: S['SessionQueueItem']['session'] +): NodeExecutionState[] => { + const nodeExecutionStates: NodeExecutionState[] = []; + + for (const [nodeId, preparedNodeIds] of Object.entries(session.source_prepared_mapping)) { + const outputs = preparedNodeIds.flatMap((preparedNodeId) => { + const result = session.results[preparedNodeId]; + return result ? [result] : []; + }); + + if (outputs.length === 0) { + continue; + } + + nodeExecutionStates.push({ + ...getInitialNodeExecutionState(nodeId), + status: zNodeStatus.enum.COMPLETED, + outputs, + }); + } + + return nodeExecutionStates; +}; + +export const getCompletedInvocationIdsFromCompletedSession = (session: S['SessionQueueItem']['session']): string[] => { + const completedInvocationIds: string[] = []; + + for (const preparedNodeIds of Object.values(session.source_prepared_mapping)) { + for (const preparedNodeId of preparedNodeIds) { + if (session.results[preparedNodeId]) { + completedInvocationIds.push(preparedNodeId); + } + } + } + + return completedInvocationIds; +}; diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 1e73abb2027..4d5b5901321 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -2,8 +2,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected'; import type { AppStore } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; -import { forEach, isNil, round } from 'es-toolkit/compat'; +import { parseify } from 'common/util/serialize'; +import { isNil, round } from 'es-toolkit/compat'; import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks'; import { allEntitiesDeleted, controlLayerRecalled } from 'features/controlLayers/store/canvasSlice'; import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; @@ -26,30 +26,20 @@ import type { } from 'features/controlLayers/store/types'; import { getControlLayerState, getReferenceImageState } from 'features/controlLayers/store/util'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState'; -import { zNodeStatus } from 'features/nodes/types/invocation'; import { modelSelected } from 'features/parameters/store/actions'; import ErrorToastDescription, { getTitle } from 'features/toast/ErrorToastDescription'; import { toast, toastApi } from 'features/toast/toast'; import { t } from 'i18next'; -import { LRUCache } from 'lru-cache'; import { Trans } from 'react-i18next'; import type { ApiTagDescription } from 'services/api'; import { api, LIST_ALL_TAG, LIST_TAG } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; import { modelsApi } from 'services/api/endpoints/models'; import { queueApi } from 'services/api/endpoints/queue'; -import { - clearCompletedInvocationKeysForQueueItem, - shouldIgnoreFinishedQueueItemInvocationEvent, -} from 'services/events/invocationTracking'; -import { - getUpdatedNodeExecutionStateOnInvocationError, - getUpdatedNodeExecutionStateOnInvocationProgress, - getUpdatedNodeExecutionStateOnInvocationStarted, -} from 'services/events/nodeExecutionState'; import { buildOnInvocationComplete } from 'services/events/onInvocationComplete'; import { buildOnModelInstallError, DiscordLink, GitHubIssuesLink } from 'services/events/onModelInstallError'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; +import { createWorkflowExecutionCoordinator } from 'services/events/workflowExecutionCoordinator'; import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; @@ -72,10 +62,22 @@ const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); export const setEventListeners = ({ socket, store, setIsConnected }: SetEventListenersArg) => { const { dispatch, getState } = store; - // We can have race conditions where we receive a progress event for a queue item that has already finished. Easiest - // way to handle this is to keep track of finished queue items in a cache and ignore progress events for those. - const finishedQueueItemIds = new LRUCache({ max: 100 }); const completedInvocationKeysByItemId = new Map>(); + const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, completedInvocationKeysByItemId); + const workflowExecutionCoordinator = createWorkflowExecutionCoordinator({ + clearCanvasWorkflowIntegrationProcessing: () => dispatch(canvasWorkflowIntegrationProcessingCompleted()), + completedInvocationKeysByItemId, + getAllNodeExecutionStates: () => $nodeExecutionStates.get(), + getNodeExecutionState: (nodeId) => $nodeExecutionStates.get()[nodeId], + logReconciliationError: (error, itemId) => { + log.debug({ error: parseify(error) }, `Unable to reconcile workflow queue item ${itemId}`); + }, + onInvocationComplete, + reconcileQueueItem: (itemId) => + dispatch(queueApi.endpoints.getQueueItem.initiate(itemId, { forceRefetch: true, subscribe: false })), + setNodeExecutionState: (nodeId, state) => $nodeExecutionStates.setKey(nodeId, state), + upsertNodeExecutionState: upsertExecutionState, + }); socket.on('connect', () => { log.debug('Connected'); @@ -107,34 +109,24 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('disconnect', () => { log.debug('Disconnected'); + workflowExecutionCoordinator.cancelPendingWorkflowReconciliations(); $lastProgressEvent.set(null); $loadingModelsCount.set(0); setIsConnected(false); }); socket.on('invocation_started', (data) => { - if (shouldIgnoreFinishedQueueItemInvocationEvent('invocation_started', finishedQueueItemIds, data.item_id)) { - return; - } const { invocation_source_id, invocation } = data; log.debug({ data } as JsonObject, `Invocation started (${invocation.type}, ${invocation_source_id})`); - const nes = $nodeExecutionStates.get()[invocation_source_id]; - const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationStarted( - nes, - data, - completedInvocationKeysByItemId - ); - if (updatedNodeExecutionState) { - upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); - } + workflowExecutionCoordinator.onInvocationStarted(data); }); socket.on('invocation_progress', (data) => { - if (shouldIgnoreFinishedQueueItemInvocationEvent('invocation_progress', finishedQueueItemIds, data.item_id)) { + if (!workflowExecutionCoordinator.onInvocationProgress(data)) { log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`); return; } - const { invocation_source_id, invocation, origin, percentage, message } = data; + const { invocation_source_id, invocation, percentage, message } = data; let _message = 'Invocation progress'; if (message) { @@ -148,36 +140,17 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, _message); $lastProgressEvent.set(data); - - if (origin === 'workflows') { - const nes = $nodeExecutionStates.get()[invocation_source_id]; - const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationProgress( - nes, - data, - completedInvocationKeysByItemId - ); - if (updatedNodeExecutionState) { - upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); - } - } }); socket.on('invocation_error', (data) => { const { invocation_source_id, invocation } = data; log.error({ data } as JsonObject, `Invocation error (${invocation.type}, ${invocation_source_id})`); - const nes = $nodeExecutionStates.get()[invocation_source_id]; - const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationError(nes, data); - if (updatedNodeExecutionState) { - upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); - } - // Clear canvas workflow integration processing state on error - if (data.origin === 'canvas_workflow_integration') { - dispatch(canvasWorkflowIntegrationProcessingCompleted()); - } + workflowExecutionCoordinator.onInvocationError(data); }); - const onInvocationComplete = buildOnInvocationComplete(getState, dispatch, completedInvocationKeysByItemId); - socket.on('invocation_complete', onInvocationComplete); + socket.on('invocation_complete', (data) => { + workflowExecutionCoordinator.onInvocationComplete(data); + }); socket.on('model_load_started', (data) => { const { config, submodel_type } = data; @@ -389,8 +362,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis }); socket.on('queue_item_status_changed', (data) => { - if (finishedQueueItemIds.has(data.item_id)) { - log.trace({ data }, `Received event for already-finished queue item ${data.item_id}`); + if (!workflowExecutionCoordinator.onQueueItemStatusChanged(data)) { return; } @@ -463,22 +435,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } dispatch(queueApi.util.invalidateTags(tagsToInvalidate)); - if (status === 'in_progress') { - forEach($nodeExecutionStates.get(), (nes) => { - if (!nes) { - return; - } - const clone = deepClone(nes); - clone.status = zNodeStatus.enum.PENDING; - clone.error = null; - clone.progress = null; - clone.progressImage = null; - clone.outputs = []; - $nodeExecutionStates.setKey(clone.nodeId, clone); - }); - } else if (status === 'completed' || status === 'failed' || status === 'canceled') { - finishedQueueItemIds.set(item_id, true); - clearCompletedInvocationKeysForQueueItem(completedInvocationKeysByItemId, item_id); + if (status === 'completed' || status === 'failed' || status === 'canceled') { if (status === 'failed' && error_type) { toast({ id: `INVOCATION_ERROR_${error_type}`, diff --git a/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.test.ts b/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.test.ts new file mode 100644 index 00000000000..fd0524ece1e --- /dev/null +++ b/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.test.ts @@ -0,0 +1,280 @@ +import type { NodeExecutionState } from 'features/nodes/types/invocation'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import type { S } from 'services/api/types'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { createWorkflowExecutionCoordinator } from './workflowExecutionCoordinator'; + +const createDeferredQueueItemRequest = () => { + let resolve!: (value: S['SessionQueueItem']) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { + abort: vi.fn(), + reject, + resolve, + unsubscribe: vi.fn(), + unwrap: () => promise, + }; +}; + +const buildQueueStatusEvent = ( + overrides: Partial +): S['QueueItemStatusChangedEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + status: 'completed', + status_sequence: 1, + batch_status: { + batch_id: 'batch-1', + queue_id: 'default', + pending: 0, + in_progress: 0, + completed: 1, + failed: 0, + canceled: 0, + total: 1, + }, + error_type: null, + error_message: null, + error_traceback: null, + created_at: '2026-01-01T00:00:00Z', + updated_at: '2026-01-01T00:00:00Z', + started_at: '2026-01-01T00:00:00Z', + completed_at: '2026-01-01T00:00:00Z', + ...overrides, + }) as S['QueueItemStatusChangedEvent']; + +const buildInvocationStartedEvent = ( + overrides: Partial = {} +): S['InvocationStartedEvent'] => + ({ + queue_id: 'default', + item_id: 2, + batch_id: 'batch-2', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-2', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'test_node', + }, + ...overrides, + }) as S['InvocationStartedEvent']; + +const buildInvocationCompleteEvent = ( + overrides: Partial = {} +): S['InvocationCompleteEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'test_node', + }, + result: { + type: 'image_output', + image: { image_name: 'image.png' }, + width: 512, + height: 512, + }, + ...overrides, + }) as S['InvocationCompleteEvent']; + +const buildInvocationErrorEvent = (overrides: Partial = {}): S['InvocationErrorEvent'] => + ({ + queue_id: 'default', + item_id: 1, + batch_id: 'batch-1', + origin: 'workflows', + destination: 'gallery', + user_id: 'user-1', + session_id: 'session-1', + invocation_source_id: 'node-1', + invocation: { + id: 'prepared-node-1', + type: 'test_node', + }, + error_type: 'TestError', + error_message: 'boom', + error_traceback: 'traceback', + ...overrides, + }) as S['InvocationErrorEvent']; + +const buildQueueItem = (status: S['SessionQueueItem']['status']): S['SessionQueueItem'] => + ({ + item_id: 1, + queue_id: 'default', + batch_id: 'batch-1', + session_id: 'session-1', + origin: 'workflows', + destination: 'gallery', + status, + priority: 0, + created_at: '2026-01-01T00:00:00Z', + updated_at: '2026-01-01T00:00:00Z', + session: { + source_prepared_mapping: { + 'node-1': ['prepared-node-1'], + }, + results: { + 'prepared-node-1': { + type: 'image_output', + image: { image_name: 'old-image.png' }, + width: 512, + height: 512, + }, + }, + }, + }) as unknown as S['SessionQueueItem']; + +const createCoordinatorHarness = () => { + const completedInvocationKeysByItemId = new Map>(); + const nodeExecutionStates: Record = {}; + const onInvocationComplete = vi.fn(); + const clearCanvasWorkflowIntegrationProcessing = vi.fn(); + const logReconciliationError = vi.fn(); + const queueItemRequests = new Map>(); + + const coordinator = createWorkflowExecutionCoordinator({ + clearCanvasWorkflowIntegrationProcessing, + completedInvocationKeysByItemId, + getAllNodeExecutionStates: () => nodeExecutionStates, + getNodeExecutionState: (nodeId) => nodeExecutionStates[nodeId], + logReconciliationError, + onInvocationComplete, + reconcileQueueItem: (itemId) => { + const req = createDeferredQueueItemRequest(); + queueItemRequests.set(itemId, req); + return req; + }, + setNodeExecutionState: (nodeId, state) => { + nodeExecutionStates[nodeId] = state; + }, + upsertNodeExecutionState: (nodeId, state) => { + nodeExecutionStates[nodeId] = { ...nodeExecutionStates[nodeId], ...state }; + }, + }); + + return { + clearCanvasWorkflowIntegrationProcessing, + coordinator, + nodeExecutionStates, + onInvocationComplete, + queueItemRequests, + }; +}; + +describe(createWorkflowExecutionCoordinator.name, () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('does not let stale reconciliation overwrite a newer in-progress workflow item', async () => { + const { coordinator, nodeExecutionStates, queueItemRequests } = createCoordinatorHarness(); + + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 1, status: 'completed', origin: 'workflows' }) + ); + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 2, status: 'in_progress', origin: 'workflows' }) + ); + coordinator.onInvocationStarted(buildInvocationStartedEvent({ item_id: 2 })); + + expect(nodeExecutionStates['node-1']?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + + queueItemRequests.get(1)?.resolve(buildQueueItem('completed')); + await Promise.resolve(); + await Promise.resolve(); + + expect(nodeExecutionStates['node-1']?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + expect(nodeExecutionStates['node-1']?.outputs).toEqual([]); + }); + + it('does not let a late invocation_complete from an old workflow item overwrite the active workflow item', () => { + const { coordinator, nodeExecutionStates, onInvocationComplete } = createCoordinatorHarness(); + + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 1, status: 'completed', origin: 'workflows' }) + ); + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 2, status: 'in_progress', origin: 'workflows' }) + ); + coordinator.onInvocationStarted(buildInvocationStartedEvent({ item_id: 2 })); + coordinator.onInvocationComplete(buildInvocationCompleteEvent({ item_id: 1 })); + + expect(onInvocationComplete).toHaveBeenCalledTimes(1); + expect(nodeExecutionStates['node-1']?.status).toBe(zNodeStatus.enum.IN_PROGRESS); + expect(nodeExecutionStates['node-1']?.outputs).toEqual([]); + }); + + it('still runs invocation_complete side effects after a workflow item failed', () => { + const { coordinator, onInvocationComplete } = createCoordinatorHarness(); + + coordinator.onQueueItemStatusChanged(buildQueueStatusEvent({ item_id: 1, status: 'failed', origin: 'workflows' })); + coordinator.onInvocationComplete(buildInvocationCompleteEvent({ item_id: 1 })); + + expect(onInvocationComplete).toHaveBeenCalledTimes(1); + }); + + it('reconciles completed sibling outputs from failed workflow queue items', async () => { + const { coordinator, nodeExecutionStates, queueItemRequests } = createCoordinatorHarness(); + + coordinator.onQueueItemStatusChanged(buildQueueStatusEvent({ item_id: 1, status: 'failed', origin: 'workflows' })); + + queueItemRequests.get(1)?.resolve(buildQueueItem('failed')); + await Promise.resolve(); + await Promise.resolve(); + + expect(nodeExecutionStates['node-1']?.status).toBe(zNodeStatus.enum.COMPLETED); + expect(nodeExecutionStates['node-1']?.outputs).toHaveLength(1); + }); + + it('ignores duplicate terminal queue events', () => { + const { coordinator, queueItemRequests } = createCoordinatorHarness(); + + expect( + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 1, status: 'completed', origin: 'workflows' }) + ) + ).toBe(true); + expect( + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 1, status: 'completed', origin: 'workflows' }) + ) + ).toBe(false); + + expect(queueItemRequests.size).toBe(1); + }); + + it('still clears canvas workflow integration processing on late invocation errors', () => { + const { clearCanvasWorkflowIntegrationProcessing, coordinator } = createCoordinatorHarness(); + + coordinator.onQueueItemStatusChanged( + buildQueueStatusEvent({ item_id: 1, status: 'canceled', origin: 'workflows' }) + ); + coordinator.onInvocationError( + buildInvocationErrorEvent({ + item_id: 1, + origin: 'canvas_workflow_integration', + }) + ); + + expect(clearCanvasWorkflowIntegrationProcessing).toHaveBeenCalledTimes(1); + }); +}); diff --git a/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.ts b/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.ts new file mode 100644 index 00000000000..2b97777453e --- /dev/null +++ b/invokeai/frontend/web/src/services/events/workflowExecutionCoordinator.ts @@ -0,0 +1,249 @@ +import { deepClone } from 'common/util/deepClone'; +import type { NodeExecutionState } from 'features/nodes/types/invocation'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import { LRUCache } from 'lru-cache'; +import type { S } from 'services/api/types'; +import { + clearCompletedInvocationKeysForQueueItem, + markInvocationAsCompleted, +} from 'services/events/invocationTracking'; +import { + getCompletedInvocationIdsFromCompletedSession, + getNodeExecutionStatesFromCompletedSession, + getUpdatedNodeExecutionStateOnInvocationError, + getUpdatedNodeExecutionStateOnInvocationProgress, + getUpdatedNodeExecutionStateOnInvocationStarted, +} from 'services/events/nodeExecutionState'; +import { + createWorkflowExecutionState, + transitionWorkflowExecutionState, + type WorkflowExecutionState, +} from 'services/events/workflowExecutionState'; + +type TerminalQueueStatus = Extract; + +type ReconciliationRequest = { + abort?: () => void; + unsubscribe?: () => void; + unwrap: () => Promise; +}; + +type WorkflowExecutionCoordinatorDeps = { + clearCanvasWorkflowIntegrationProcessing: () => void; + completedInvocationKeysByItemId: Map>; + getAllNodeExecutionStates: () => Record; + getNodeExecutionState: (nodeId: string) => NodeExecutionState | undefined; + logReconciliationError: (error: unknown, itemId: number) => void; + onInvocationComplete: (data: S['InvocationCompleteEvent']) => void; + reconcileQueueItem: (itemId: number) => ReconciliationRequest; + setNodeExecutionState: (nodeId: string, state: NodeExecutionState) => void; + upsertNodeExecutionState: (nodeId: string, state: NodeExecutionState) => void; +}; + +export const createWorkflowExecutionCoordinator = (deps: WorkflowExecutionCoordinatorDeps) => { + const workflowExecutionStates = new LRUCache({ max: 100 }); + const pendingWorkflowReconciliationRequests = new Map(); + let activeWorkflowQueueItemId: number | null = null; + + const transitionWorkflowEvent = ( + itemId: number, + event: Parameters[1] + ): boolean => { + const state = workflowExecutionStates.get(itemId) ?? createWorkflowExecutionState(); + const transition = transitionWorkflowExecutionState(state, event); + workflowExecutionStates.set(itemId, transition.state); + return transition.shouldApply; + }; + + const cleanupWorkflowExecutionState = (itemId: number) => { + const req = pendingWorkflowReconciliationRequests.get(itemId); + req?.abort?.(); + req?.unsubscribe?.(); + pendingWorkflowReconciliationRequests.delete(itemId); + workflowExecutionStates.delete(itemId); + clearCompletedInvocationKeysForQueueItem(deps.completedInvocationKeysByItemId, itemId); + }; + + const cancelPendingWorkflowReconciliations = () => { + for (const req of pendingWorkflowReconciliationRequests.values()) { + req.abort?.(); + req.unsubscribe?.(); + } + pendingWorkflowReconciliationRequests.clear(); + }; + + const reconcileWorkflowQueueItemResults = (itemId: number, status: TerminalQueueStatus) => { + const req = deps.reconcileQueueItem(itemId); + pendingWorkflowReconciliationRequests.set(itemId, req); + req + .unwrap() + .then((queueItem) => { + if (activeWorkflowQueueItemId !== itemId || queueItem.status !== status) { + return; + } + + const completedInvocationIds = getCompletedInvocationIdsFromCompletedSession(queueItem.session); + transitionWorkflowEvent(itemId, { + type: 'session_results_reconciled', + itemId, + status, + completedInvocationIds, + }); + for (const invocationId of completedInvocationIds) { + markInvocationAsCompleted(deps.completedInvocationKeysByItemId, itemId, invocationId); + } + for (const nodeExecutionState of getNodeExecutionStatesFromCompletedSession(queueItem.session)) { + deps.upsertNodeExecutionState(nodeExecutionState.nodeId, nodeExecutionState); + } + }) + .catch((error) => { + deps.logReconciliationError(error, itemId); + }) + .finally(() => { + pendingWorkflowReconciliationRequests.delete(itemId); + req.unsubscribe?.(); + }); + }; + + const onInvocationStarted = (data: S['InvocationStartedEvent']) => { + if ( + !transitionWorkflowEvent(data.item_id, { + type: 'invocation_started', + itemId: data.item_id, + invocationId: data.invocation.id, + }) + ) { + return; + } + + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationStarted( + deps.getNodeExecutionState(data.invocation_source_id), + data, + deps.completedInvocationKeysByItemId + ); + if (updatedNodeExecutionState) { + deps.upsertNodeExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); + } + }; + + const onInvocationProgress = (data: S['InvocationProgressEvent']) => { + if ( + !transitionWorkflowEvent(data.item_id, { + type: 'invocation_progress', + itemId: data.item_id, + invocationId: data.invocation.id, + }) + ) { + return false; + } + + if (data.origin === 'workflows') { + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationProgress( + deps.getNodeExecutionState(data.invocation_source_id), + data, + deps.completedInvocationKeysByItemId + ); + if (updatedNodeExecutionState) { + deps.upsertNodeExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); + } + } + + return true; + }; + + const onInvocationError = (data: S['InvocationErrorEvent']) => { + if ( + !transitionWorkflowEvent(data.item_id, { + type: 'invocation_error', + itemId: data.item_id, + invocationId: data.invocation.id, + }) + ) { + if (data.origin === 'canvas_workflow_integration') { + deps.clearCanvasWorkflowIntegrationProcessing(); + } + return; + } + + const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationError( + deps.getNodeExecutionState(data.invocation_source_id), + data + ); + if (updatedNodeExecutionState) { + deps.upsertNodeExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState); + } + if (data.origin === 'canvas_workflow_integration') { + deps.clearCanvasWorkflowIntegrationProcessing(); + } + }; + + const onInvocationComplete = (data: S['InvocationCompleteEvent']) => { + if ( + data.origin === 'workflows' && + activeWorkflowQueueItemId !== null && + activeWorkflowQueueItemId !== data.item_id + ) { + markInvocationAsCompleted(deps.completedInvocationKeysByItemId, data.item_id, data.invocation.id); + deps.onInvocationComplete(data); + return; + } + + transitionWorkflowEvent(data.item_id, { + type: 'invocation_complete', + itemId: data.item_id, + invocationId: data.invocation.id, + }); + deps.onInvocationComplete(data); + }; + + const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => { + if ( + !transitionWorkflowEvent(data.item_id, { + type: 'queue_item_status_changed', + itemId: data.item_id, + status: data.status, + }) + ) { + return false; + } + + if (data.origin === 'workflows') { + if (activeWorkflowQueueItemId !== null && activeWorkflowQueueItemId !== data.item_id) { + cleanupWorkflowExecutionState(activeWorkflowQueueItemId); + } + activeWorkflowQueueItemId = data.item_id; + } + + if (data.status === 'in_progress') { + for (const nes of Object.values(deps.getAllNodeExecutionStates())) { + if (!nes) { + continue; + } + const clone = deepClone(nes); + clone.status = zNodeStatus.enum.PENDING; + clone.error = null; + clone.progress = null; + clone.progressImage = null; + clone.outputs = []; + deps.setNodeExecutionState(clone.nodeId, clone); + } + } else if (data.status === 'completed' || data.status === 'failed' || data.status === 'canceled') { + if (data.origin === 'workflows') { + reconcileWorkflowQueueItemResults(data.item_id, data.status); + } else { + cleanupWorkflowExecutionState(data.item_id); + } + } + + return true; + }; + + return { + cancelPendingWorkflowReconciliations, + onInvocationComplete, + onInvocationError, + onInvocationProgress, + onInvocationStarted, + onQueueItemStatusChanged, + }; +}; diff --git a/invokeai/frontend/web/src/services/events/workflowExecutionState.test.ts b/invokeai/frontend/web/src/services/events/workflowExecutionState.test.ts new file mode 100644 index 00000000000..2e99176db8c --- /dev/null +++ b/invokeai/frontend/web/src/services/events/workflowExecutionState.test.ts @@ -0,0 +1,180 @@ +import { describe, expect, it } from 'vitest'; + +import { createWorkflowExecutionState, transitionWorkflowExecutionState } from './workflowExecutionState'; + +describe(transitionWorkflowExecutionState.name, () => { + it('allows invocation completion after the queue item has already completed', () => { + let state = createWorkflowExecutionState(); + + const queueTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'completed', + }); + + expect(queueTransition.shouldApply).toBe(true); + state = queueTransition.state; + + const completionTransition = transitionWorkflowExecutionState(state, { + type: 'invocation_complete', + itemId: 1, + invocationId: 'prepared-node-1', + }); + + expect(completionTransition.shouldApply).toBe(true); + expect(completionTransition.state.invocations['prepared-node-1']).toBe('completed'); + }); + + it('ignores late progress and error events after an invocation completed', () => { + let state = createWorkflowExecutionState(); + + const completionTransition = transitionWorkflowExecutionState(state, { + type: 'invocation_complete', + itemId: 1, + invocationId: 'prepared-node-1', + }); + + expect(completionTransition.shouldApply).toBe(true); + state = completionTransition.state; + + expect( + transitionWorkflowExecutionState(state, { + type: 'invocation_progress', + itemId: 1, + invocationId: 'prepared-node-1', + }).shouldApply + ).toBe(false); + + expect( + transitionWorkflowExecutionState(state, { + type: 'invocation_error', + itemId: 1, + invocationId: 'prepared-node-1', + }).shouldApply + ).toBe(false); + }); + + it('ignores stale non-terminal queue status after a terminal queue status', () => { + let state = createWorkflowExecutionState(); + + const completedTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'completed', + }); + + expect(completedTransition.shouldApply).toBe(true); + state = completedTransition.state; + + const staleTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'in_progress', + }); + + expect(staleTransition.shouldApply).toBe(false); + expect(staleTransition.state.queueStatus).toBe('completed'); + }); + + it('allows completed sibling invocations after a failed queue item', () => { + let state = createWorkflowExecutionState(); + + const failedTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'failed', + }); + + expect(failedTransition.shouldApply).toBe(true); + state = failedTransition.state; + + const lateCompletionTransition = transitionWorkflowExecutionState(state, { + type: 'invocation_complete', + itemId: 1, + invocationId: 'prepared-node-1', + }); + + expect(lateCompletionTransition.shouldApply).toBe(true); + expect(lateCompletionTransition.state.queueStatus).toBe('failed'); + expect(lateCompletionTransition.state.invocations['prepared-node-1']).toBe('completed'); + }); + + it('treats reconciled completed invocations as terminal after a matching terminal queue transition', () => { + let state = createWorkflowExecutionState(); + + const queueTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'completed', + }); + + expect(queueTransition.shouldApply).toBe(true); + state = queueTransition.state; + + const reconciliationTransition = transitionWorkflowExecutionState(state, { + type: 'session_results_reconciled', + itemId: 1, + status: 'completed', + completedInvocationIds: ['prepared-node-1'], + }); + + expect(reconciliationTransition.shouldApply).toBe(true); + state = reconciliationTransition.state; + expect(state.queueStatus).toBe('completed'); + expect(state.invocations['prepared-node-1']).toBe('completed'); + + const lateCompletionTransition = transitionWorkflowExecutionState(state, { + type: 'invocation_complete', + itemId: 1, + invocationId: 'prepared-node-1', + }); + + expect(lateCompletionTransition.shouldApply).toBe(false); + }); + + it('ignores duplicate terminal queue events', () => { + let state = createWorkflowExecutionState(); + + const completedTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'completed', + }); + + expect(completedTransition.shouldApply).toBe(true); + state = completedTransition.state; + + const duplicateTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'completed', + }); + + expect(duplicateTransition.shouldApply).toBe(false); + expect(duplicateTransition.state.queueStatus).toBe('completed'); + }); + + it('does not reconcile session results unless the terminal status matches', () => { + let state = createWorkflowExecutionState(); + + const failedTransition = transitionWorkflowExecutionState(state, { + type: 'queue_item_status_changed', + itemId: 1, + status: 'failed', + }); + + expect(failedTransition.shouldApply).toBe(true); + state = failedTransition.state; + + const staleReconciliationTransition = transitionWorkflowExecutionState(state, { + type: 'session_results_reconciled', + itemId: 1, + status: 'completed', + completedInvocationIds: ['prepared-node-1'], + }); + + expect(staleReconciliationTransition.shouldApply).toBe(false); + expect(staleReconciliationTransition.state.queueStatus).toBe('failed'); + expect(staleReconciliationTransition.state.invocations['prepared-node-1']).toBeUndefined(); + }); +}); diff --git a/invokeai/frontend/web/src/services/events/workflowExecutionState.ts b/invokeai/frontend/web/src/services/events/workflowExecutionState.ts new file mode 100644 index 00000000000..bb571447ef0 --- /dev/null +++ b/invokeai/frontend/web/src/services/events/workflowExecutionState.ts @@ -0,0 +1,102 @@ +import type { S } from 'services/api/types'; + +type QueueStatus = NonNullable; +type InvocationStatus = 'in_progress' | 'completed' | 'failed'; + +type WorkflowExecutionEvent = + | { + type: 'queue_item_status_changed'; + itemId: number; + status: QueueStatus; + } + | { + type: 'session_results_reconciled'; + itemId: number; + status: Extract; + completedInvocationIds: string[]; + } + | { + type: 'invocation_started' | 'invocation_progress' | 'invocation_complete' | 'invocation_error'; + itemId: number; + invocationId: string; + }; + +export type WorkflowExecutionState = { + itemId: number | null; + queueStatus: QueueStatus | null; + invocations: Record; +}; + +type WorkflowExecutionTransition = { + state: WorkflowExecutionState; + shouldApply: boolean; +}; + +const TERMINAL_QUEUE_STATUSES = new Set(['completed', 'failed', 'canceled']); +const TERMINAL_INVOCATION_STATUSES = new Set(['completed', 'failed']); + +const isTerminalQueueStatus = (status: QueueStatus | null) => status !== null && TERMINAL_QUEUE_STATUSES.has(status); + +export const createWorkflowExecutionState = (): WorkflowExecutionState => ({ + itemId: null, + queueStatus: null, + invocations: {}, +}); + +export const transitionWorkflowExecutionState = ( + state: WorkflowExecutionState, + event: WorkflowExecutionEvent +): WorkflowExecutionTransition => { + const nextState: WorkflowExecutionState = { + itemId: state.itemId ?? event.itemId, + queueStatus: state.queueStatus, + invocations: { ...state.invocations }, + }; + + if (event.type === 'queue_item_status_changed') { + if (isTerminalQueueStatus(state.queueStatus)) { + return { state, shouldApply: false }; + } + + nextState.queueStatus = event.status; + return { state: nextState, shouldApply: true }; + } + + if (event.type === 'session_results_reconciled') { + if (state.queueStatus !== event.status) { + return { state, shouldApply: false }; + } + + for (const invocationId of event.completedInvocationIds) { + nextState.invocations[invocationId] = 'completed'; + } + return { state: nextState, shouldApply: true }; + } + + const invocationStatus = state.invocations[event.invocationId]; + if (invocationStatus && TERMINAL_INVOCATION_STATUSES.has(invocationStatus)) { + return { state, shouldApply: false }; + } + + if (event.type === 'invocation_started' || event.type === 'invocation_progress') { + if (isTerminalQueueStatus(state.queueStatus)) { + return { state, shouldApply: false }; + } + nextState.invocations[event.invocationId] = 'in_progress'; + return { state: nextState, shouldApply: true }; + } + + if (event.type === 'invocation_error') { + if (state.queueStatus === 'completed' || state.queueStatus === 'canceled') { + return { state, shouldApply: false }; + } + nextState.invocations[event.invocationId] = 'failed'; + return { state: nextState, shouldApply: true }; + } + + if (state.queueStatus === 'canceled') { + return { state, shouldApply: false }; + } + nextState.invocations[event.invocationId] = 'completed'; + return { state: nextState, shouldApply: true }; +};