diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index afff17394a..db40397782 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -723,6 +723,11 @@ Special handling for Colab: - `background_callback_manager` - DiskcacheManager or CeleryManager - `on_error` - Global callback error handler +**WebSocket Callbacks:** +- `websocket_callbacks` - Enable WebSocket for all callbacks (default: `False`). Requires FastAPI backend. +- `websocket_allowed_origins` - List of allowed origins for WebSocket connections +- `websocket_inactivity_timeout` - Disconnect WebSocket after inactivity period in ms (default: `300000` = 5 minutes). Set to `0` to disable. + ### app.run() Parameters - `host` - Server IP (default: `"127.0.0.1"`, env: `HOST`) @@ -861,6 +866,177 @@ async def async_background(n_clicks): Both DiskcacheManager and CeleryManager support async functions via `asyncio.run()`. +## WebSocket Callbacks + +WebSocket callbacks use a persistent WebSocket connection instead of HTTP POST for callback execution. This reduces latency and connection overhead for applications with frequent callbacks. + +### Requirements + +- **FastAPI backend required**: WebSocket callbacks only work with FastAPI +- **SharedWorker support**: Modern browsers (not IE) + +### Usage + +**Enable globally for all callbacks:** +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server, websocket_callbacks=True) +``` + +**Enable per-callback:** +```python +@app.callback( + Output('output', 'children'), + Input('input', 'value'), + websocket=True # Use WebSocket for this callback only +) +def update(value): + return f"Value: {value}" +``` + +### Configuration + +```python +app = Dash( + __name__, + server=server, + websocket_callbacks=True, + websocket_inactivity_timeout=300000, # 5 minutes (default) + websocket_allowed_origins=['https://example.com'], +) +``` + +- **`websocket_callbacks`** - Enable WebSocket for all callbacks (default: `False`) +- **`websocket_inactivity_timeout`** - Close WebSocket after period of inactivity in milliseconds (default: `300000` = 5 minutes). Heartbeats do not count as activity. Set to `0` to disable timeout. Connection automatically reconnects when needed. +- **`websocket_allowed_origins`** - List of allowed origins for WebSocket connections (security) + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Browser Tab 1 Browser Tab 2 │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ Renderer │ │ Renderer │ │ +│ └──────┬──────┘ └──────┬──────┘ │ +│ │ postMessage │ postMessage │ +│ └────────────┬───────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────┐ │ +│ │ SharedWorker │ (one per origin) │ +│ │ dash-ws-worker │ │ +│ └──────────┬──────────┘ │ +└────────────────────│────────────────────────────────────────────────────┘ + │ WebSocket + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Server (FastAPI) │ +│ WebSocket Endpoint: /_dash-ws-callback │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +**Connection & Reconnection Flow:** +``` +Renderer SharedWorker Server + │ │ │ + │──[CONNECT]──────────────────>│ │ + │ │──[WebSocket Connect]──>│ + │<─[CONNECTED]─────────────────│<─[Connected]───────────│ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[callback request]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[callback response]───│ + │ │ │ + │ (inactivity) │ (heartbeat check) │ + │ │──[close 4001]─────────>│ + │<─[DISCONNECTED]──────────────│ │ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[reconnect + send]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[response]────────────│ +``` + +- **SharedWorker**: Single WebSocket connection shared across browser tabs +- **Heartbeat**: Periodic ping/pong to detect dead connections (30s interval) +- **Inactivity timeout**: Closes connection after no actual callback activity (not heartbeats) +- **Auto-reconnect**: Reconnects automatically when a callback is triggered after timeout + +### Long-Running Callbacks with set_props/get_props + +WebSocket callbacks can stream updates to the client during execution using `set_props()` and read current component values using `ctx.get_websocket()`: + +```python +import asyncio +from dash import callback, Output, Input, set_props, ctx + +@callback( + Output('result', 'children'), + Input('start-btn', 'n_clicks'), + prevent_initial_call=True +) +async def long_running_task(n_clicks): + ws = ctx.get_websocket() + if not ws: + return "WebSocket not available" + + # Stream progress updates to the client + for i in range(100): + await asyncio.sleep(0.1) + set_props('progress-bar', {'value': i + 1}) + set_props('status', {'children': f'Processing step {i + 1}/100...'}) + + # Read current value from another component + current_value = await ws.get_prop('input-field', 'value') + + return f"Completed! Input was: {current_value}" +``` + +**API:** +- `set_props(component_id, props_dict)` - Stream prop updates immediately to client +- `ctx.get_websocket()` - Get WebSocket interface (returns `None` if not in WS context) +- `await ws.get_prop(component_id, prop_name)` - Read current prop value from client +- `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version) +- `await ws.close(code, reason)` - Close the WebSocket connection + +### Connection Hooks + +Use hooks to validate connections and messages: + +```python +from dash import Dash, hooks + +@hooks.websocket_connect() +async def validate_connection(websocket): + """Validate WebSocket connection before accepting.""" + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True # Allow connection + +@hooks.websocket_message() +async def validate_message(websocket, message): + """Validate each WebSocket message.""" + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True # Allow message +``` + +**Hook Return Values:** +- `True` (or truthy) - Allow connection/message +- `False` - Reject with default code (4001) +- `(code, reason)` - Reject with custom close code and reason + +### Key Files + +- `dash/dash.py` - WebSocket config in `_generate_config()` +- `dash/dash-renderer/src/utils/workerClient.ts` - Browser-side SharedWorker client +- `@plotly/dash-websocket-worker/src/WebSocketManager.ts` - WebSocket connection management +- `@plotly/dash-websocket-worker/src/worker.ts` - SharedWorker entry point +- `dash/backends/_fastapi.py` - Server-side WebSocket handler + ## Security ### XSS Protection diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b15bba3a3b..80d84a8188 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -19,6 +19,7 @@ jobs: backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} dcc_paths_changed: ${{ steps.filter.outputs.dcc_related_paths }} html_paths_changed: ${{ steps.filter.outputs.html_related_paths }} + websocket_changed: ${{ steps.filter.outputs.websocket_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -48,6 +49,18 @@ jobs: backend_paths: - 'dash/backends/**' - 'tests/backend_tests/**' + websocket_paths: + - 'dash/backends/_fastapi.py' + - 'dash/backends/_quart.py' + - 'dash/backends/base_server.py' + - 'dash/_callback.py' + - 'dash/_callback_context.py' + - 'dash/_hooks.py' + - 'dash/dash.py' + - '@dash-websocket-worker/**' + - 'dash/dash-renderer/src/**' + - 'tests/websocket/**' + - 'requirements/**' lint-unit: name: Lint & Unit Tests (Python ${{ matrix.python-version }}) @@ -366,7 +379,7 @@ jobs: - name: Set up Node.js uses: actions/setup-node@v4 with: - node-version: '20' + node-version: '24' cache: 'npm' - name: Install Node.js dependencies @@ -377,6 +390,7 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: 'pip' + cache-dependency-path: requirements/*.txt - name: Download built Dash packages uses: actions/download-artifact@v4 @@ -387,43 +401,13 @@ jobs: - name: Install Dash packages run: | python -m pip install --upgrade pip wheel - python -m pip install "setuptools<78.0.0" - python -m pip install "selenium==4.32.0" + python -m pip install "setuptools<80.0.0" find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; - - name: Install Google Chrome - run: | - sudo apt-get update - sudo apt-get install -y google-chrome-stable - - - name: Install ChromeDriver - run: | - echo "Determining Chrome version..." - CHROME_BROWSER_VERSION=$(google-chrome --version) - echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION" - CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.') - echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION" - if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then - echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..." - CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") - if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then - echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints." - exit 1 - fi - CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip" - else - echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..." - CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") - CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip" - fi - echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING" - echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL" - wget -q -O chromedriver.zip "$CHROMEDRIVER_URL" - unzip -o chromedriver.zip -d /tmp/ - sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver - sudo chmod +x /usr/local/bin/chromedriver - echo "/usr/local/bin" >> $GITHUB_PATH - shell: bash + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable - name: Build/Setup test components run: npm run setup-tests.py @@ -558,6 +542,67 @@ jobs: path: components/dash-table/test-reports/ retention-days: 7 + websocket-tests: + name: WebSocket Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.websocket_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '24' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: requirements/*.txt + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<80.0.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,fastapi,quart]"' \; + + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run WebSocket tests + run: | + mkdir wstests + cp -r tests wstests/tests + cd wstests + touch __init__.py + pytest --headless --nopercyfinalize tests/websocket -v -s + test-main: name: Main Dash Tests (Python ${{ matrix.python-version }}, Group ${{ matrix.test-group }}) needs: build diff --git a/@plotly/dash-websocket-worker/README.md b/@plotly/dash-websocket-worker/README.md new file mode 100644 index 0000000000..64e37a1987 --- /dev/null +++ b/@plotly/dash-websocket-worker/README.md @@ -0,0 +1,3 @@ +# Dash websocket worker + +Worker for websocket based callbacks. diff --git a/@plotly/dash-websocket-worker/package.json b/@plotly/dash-websocket-worker/package.json new file mode 100644 index 0000000000..619a842380 --- /dev/null +++ b/@plotly/dash-websocket-worker/package.json @@ -0,0 +1,29 @@ +{ + "name": "@plotly/dash-websocket-worker", + "version": "1.0.0", + "description": "SharedWorker for WebSocket-based Dash callbacks", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "build": "webpack --mode production", + "build:dev": "webpack --mode development", + "watch": "webpack --mode development --watch", + "clean": "rm -rf dist" + }, + "files": [ + "dist" + ], + "keywords": [ + "dash", + "websocket", + "sharedworker" + ], + "author": "Plotly", + "license": "MIT", + "devDependencies": { + "typescript": "^5.0.0", + "webpack": "^5.0.0", + "webpack-cli": "^5.0.0", + "ts-loader": "^9.0.0" + } +} diff --git a/@plotly/dash-websocket-worker/src/MessageRouter.ts b/@plotly/dash-websocket-worker/src/MessageRouter.ts new file mode 100644 index 0000000000..68a9f4bfc2 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/MessageRouter.ts @@ -0,0 +1,207 @@ +import { + WorkerMessageType, + WorkerMessage, + CallbackRequestMessage, + GetPropsResponseMessage, + SetPropsMessage, + GetPropsRequestMessage, + CallbackResponseMessage +} from './types'; + +/** + * Routes messages between renderers (via MessagePorts) and the WebSocket server. + */ +export class MessageRouter { + /** Map of renderer IDs to their MessagePorts */ + private renderers: Map = new Map(); + + /** Callback to send messages to the WebSocket server */ + public sendToServer: ((message: unknown) => void) | null = null; + + /** + * Register a renderer with its MessagePort. + * @param rendererId Unique identifier for the renderer + * @param port The MessagePort for communication + */ + public registerRenderer(rendererId: string, port: MessagePort): void { + this.renderers.set(rendererId, port); + } + + /** + * Unregister a renderer. + * @param rendererId The renderer to unregister + */ + public unregisterRenderer(rendererId: string): void { + this.renderers.delete(rendererId); + } + + /** + * Get the number of connected renderers. + */ + public get rendererCount(): number { + return this.renderers.size; + } + + /** + * Handle a message from a renderer. + * @param rendererId The ID of the renderer that sent the message + * @param message The message from the renderer + */ + public handleRendererMessage(rendererId: string, message: WorkerMessage): void { + switch (message.type) { + case WorkerMessageType.CALLBACK_REQUEST: + this.forwardCallbackRequest(rendererId, message as CallbackRequestMessage); + break; + + case WorkerMessageType.GET_PROPS_RESPONSE: + this.forwardGetPropsResponse(rendererId, message as GetPropsResponseMessage); + break; + + default: + console.warn(`Unknown message type from renderer: ${message.type}`); + } + } + + /** + * Handle a message from the WebSocket server. + * @param message The message from the server + */ + public handleServerMessage(message: unknown): void { + const msg = message as WorkerMessage; + const rendererId = msg.rendererId; + + switch (msg.type) { + case WorkerMessageType.CALLBACK_RESPONSE: + this.forwardToRenderer(rendererId, msg as CallbackResponseMessage); + break; + + case WorkerMessageType.SET_PROPS: + this.forwardSetProps(rendererId, msg as SetPropsMessage); + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + this.forwardGetPropsRequest(rendererId, msg as GetPropsRequestMessage); + break; + + case WorkerMessageType.ERROR: + this.forwardToRenderer(rendererId, msg); + break; + + default: + console.warn(`Unknown message type from server: ${msg.type}`); + } + } + + /** + * Send a message to all connected renderers. + * @param message The message to broadcast + */ + public broadcastToRenderers(message: WorkerMessage): void { + for (const [rendererId, port] of this.renderers) { + try { + port.postMessage(message); + } catch (error) { + // Port may be closed if tab was closed + console.warn(`Failed to send to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a connected notification to a specific renderer. + * @param rendererId The renderer to notify + */ + public notifyConnected(rendererId: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.CONNECTED, + rendererId + }); + } catch (error) { + console.warn(`Failed to notify renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + /** + * Send a disconnected notification to all renderers. + * @param reason Optional reason for disconnection + */ + public notifyDisconnected(reason?: string): void { + this.broadcastToRenderers({ + type: WorkerMessageType.DISCONNECTED, + rendererId: '', + payload: { reason } + }); + } + + /** + * Send an error notification to a specific renderer. + * @param rendererId The renderer to notify + * @param message Error message + * @param code Optional error code + */ + public notifyError(rendererId: string, message: string, code?: string): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage({ + type: WorkerMessageType.ERROR, + rendererId, + payload: { message, code } + }); + } catch (error) { + console.warn(`Failed to send error to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } + } + + private forwardCallbackRequest(rendererId: string, message: CallbackRequestMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardGetPropsResponse(rendererId: string, message: GetPropsResponseMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardToRenderer(rendererId: string, message: WorkerMessage): void { + const port = this.renderers.get(rendererId); + if (port) { + try { + port.postMessage(message); + } catch (error) { + console.warn(`Failed to forward to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } + } else { + console.warn(`Renderer ${rendererId} not found for message`); + } + } + + private forwardSetProps(rendererId: string, message: SetPropsMessage): void { + this.forwardToRenderer(rendererId, message); + } + + private forwardGetPropsRequest(rendererId: string, message: GetPropsRequestMessage): void { + this.forwardToRenderer(rendererId, message); + } +} diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts new file mode 100644 index 0000000000..f7abe18dda --- /dev/null +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -0,0 +1,301 @@ +/** + * Configuration options for WebSocket connection. + */ +interface WebSocketConfig { + /** Maximum number of reconnection attempts */ + maxRetries: number; + /** Initial delay between reconnection attempts (ms) */ + initialRetryDelay: number; + /** Maximum delay between reconnection attempts (ms) */ + maxRetryDelay: number; + /** Heartbeat interval (ms) */ + heartbeatInterval: number; + /** Heartbeat timeout (ms) */ + heartbeatTimeout: number; + /** Inactivity timeout (ms) - 0 to disable */ + inactivityTimeout: number; +} + +const DEFAULT_CONFIG: WebSocketConfig = { + maxRetries: 10, + initialRetryDelay: 1000, + maxRetryDelay: 30000, + heartbeatInterval: 30000, + heartbeatTimeout: 10000, + inactivityTimeout: 300000 // 5 minutes default +}; + +/** + * Manages WebSocket connection with automatic reconnection and heartbeat. + */ +export class WebSocketManager { + private ws: WebSocket | null = null; + private serverUrl: string | null = null; + private config: WebSocketConfig; + private retryCount = 0; + private retryTimeout: ReturnType | null = null; + private heartbeatInterval: ReturnType | null = null; + private heartbeatTimeout: ReturnType | null = null; + private lastActivityTime: number = Date.now(); + private messageQueue: string[] = []; + private isConnecting = false; + + /** Callback when connection is established */ + public onOpen: (() => void) | null = null; + /** Callback when connection is closed */ + public onClose: ((reason?: string) => void) | null = null; + /** Callback when a message is received */ + public onMessage: ((data: unknown) => void) | null = null; + /** Callback when an error occurs */ + public onError: ((error: Error) => void) | null = null; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + /** + * Update configuration options. + * Only updates the provided options, keeping others unchanged. + * @param config Partial configuration to merge + */ + public setConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + } + + /** + * Connect to the WebSocket server. + * @param serverUrl The WebSocket server URL + */ + public connect(serverUrl: string): void { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + // Already connected + return; + } + + if (this.isConnecting) { + // Connection in progress + return; + } + + this.serverUrl = serverUrl; + this.isConnecting = true; + this.createConnection(); + } + + /** + * Disconnect from the WebSocket server. + */ + public disconnect(): void { + this.cleanup(); + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(1000, 'Client disconnect'); + } + this.ws = null; + this.serverUrl = null; + this.retryCount = 0; + } + + /** + * Send a message through the WebSocket connection. + * If not connected, queues the message and triggers reconnection. + * @param message The message to send + */ + public send(message: unknown): void { + const data = JSON.stringify(message); + + // Track activity for non-heartbeat messages + const msgObj = message as { type?: string }; + if (msgObj.type !== 'heartbeat') { + this.lastActivityTime = Date.now(); + } + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(data); + } else { + // Queue message for when connection is established + this.messageQueue.push(data); + + // Trigger reconnect if we have a server URL but aren't connected/connecting + if (this.serverUrl && !this.isConnecting) { + this.isConnecting = true; + // Reset retry count since this is user-initiated activity + this.retryCount = 0; + this.createConnection(); + } + } + } + + /** + * Check if the WebSocket is currently connected. + */ + public get isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + private createConnection(): void { + if (!this.serverUrl) { + return; + } + + try { + this.ws = new WebSocket(this.serverUrl); + this.ws.onopen = this.handleOpen.bind(this); + this.ws.onclose = this.handleClose.bind(this); + this.ws.onmessage = this.handleMessage.bind(this); + this.ws.onerror = this.handleError.bind(this); + } catch (error) { + this.isConnecting = false; + this.scheduleReconnect(); + } + } + + private handleOpen(): void { + this.isConnecting = false; + this.retryCount = 0; + this.lastActivityTime = Date.now(); + + // Flush queued messages + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message && this.ws) { + this.ws.send(message); + } + } + + // Start heartbeat (also handles inactivity check) + this.startHeartbeat(); + + if (this.onOpen) { + this.onOpen(); + } + } + + private handleClose(event: CloseEvent): void { + this.isConnecting = false; + this.cleanup(); + + const reason = event.reason || 'Connection closed'; + + if (this.onClose) { + this.onClose(reason); + } + + // Only reconnect if: + // - We haven't explicitly disconnected (code 1000) + // - It's not an inactivity timeout (code 4001) + if (this.serverUrl && event.code !== 1000 && event.code !== 4001) { + this.scheduleReconnect(); + } + } + + private handleMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + + // Handle heartbeat acknowledgment - does NOT count as activity + if (data.type === 'heartbeat_ack') { + this.clearHeartbeatTimeout(); + return; + } + + // Track activity for actual callback messages + this.lastActivityTime = Date.now(); + + if (this.onMessage) { + this.onMessage(data); + } + } catch (error) { + if (this.onError) { + this.onError(new Error('Failed to parse message')); + } + } + } + + private handleError(): void { + this.isConnecting = false; + // WebSocket error events don't contain useful information + // The close event will follow with more details + } + + private scheduleReconnect(): void { + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + } + + if (this.retryCount >= this.config.maxRetries) { + if (this.onError) { + this.onError(new Error('Max reconnection attempts reached')); + } + return; + } + + // Exponential backoff with jitter + const delay = Math.min( + this.config.initialRetryDelay * Math.pow(2, this.retryCount) + + Math.random() * 1000, + this.config.maxRetryDelay + ); + + this.retryCount++; + + this.retryTimeout = setTimeout(() => { + this.createConnection(); + }, delay); + } + + private startHeartbeat(): void { + this.stopHeartbeat(); + + this.heartbeatInterval = setInterval(() => { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return; + } + + // Check for inactivity timeout + if (this.config.inactivityTimeout > 0) { + const timeSinceActivity = Date.now() - this.lastActivityTime; + if (timeSinceActivity >= this.config.inactivityTimeout) { + this.ws.close(4001, 'Inactivity timeout'); + return; + } + } + + this.ws.send(JSON.stringify({ type: 'heartbeat' })); + this.setHeartbeatTimeout(); + }, this.config.heartbeatInterval); + } + + private stopHeartbeat(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + this.clearHeartbeatTimeout(); + } + + private setHeartbeatTimeout(): void { + this.clearHeartbeatTimeout(); + + this.heartbeatTimeout = setTimeout(() => { + // Heartbeat timeout - connection may be dead + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.close(4000, 'Heartbeat timeout'); + } + }, this.config.heartbeatTimeout); + } + + private clearHeartbeatTimeout(): void { + if (this.heartbeatTimeout) { + clearTimeout(this.heartbeatTimeout); + this.heartbeatTimeout = null; + } + } + + private cleanup(): void { + this.stopHeartbeat(); + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + this.retryTimeout = null; + } + } +} diff --git a/@plotly/dash-websocket-worker/src/index.ts b/@plotly/dash-websocket-worker/src/index.ts new file mode 100644 index 0000000000..e21b382d41 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/index.ts @@ -0,0 +1,18 @@ +/** + * Dash WebSocket Worker Package + * + * Provides a SharedWorker for WebSocket-based Dash callbacks. + */ + +export * from './types'; + +/** + * Get the URL for the WebSocket worker script. + * This should be used to instantiate the SharedWorker. + * + * @param baseUrl Base URL where the worker script is served + * @returns Full URL to the worker script + */ +export function getWorkerUrl(baseUrl: string): string { + return `${baseUrl}/dash-ws-worker.js`; +} diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts new file mode 100644 index 0000000000..fac282b5e1 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -0,0 +1,151 @@ +/** + * Message types for communication between renderer and worker. + */ +export enum WorkerMessageType { + // Renderer -> Worker + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + + // Worker -> Renderer + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** + * Base message structure for worker communication. + */ +export interface WorkerMessage { + type: WorkerMessageType; + rendererId: string; + requestId?: string; + payload?: unknown; +} + +/** + * Message from renderer to worker requesting connection. + */ +export interface ConnectMessage extends WorkerMessage { + type: WorkerMessageType.CONNECT; + payload: { + serverUrl: string; + inactivityTimeout?: number; + }; +} + +/** + * Message from renderer to worker requesting disconnect. + */ +export interface DisconnectMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECT; +} + +/** + * Callback request payload structure. + */ +export interface CallbackPayload { + output: string; + outputs: unknown[]; + inputs: unknown[]; + state?: unknown[]; + changedPropIds: string[]; + parsedChangedPropsIds?: string[]; +} + +/** + * Message from renderer to worker with callback request. + */ +export interface CallbackRequestMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_REQUEST; + payload: CallbackPayload; +} + +/** + * Message from worker to renderer with callback response. + */ +export interface CallbackResponseMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_RESPONSE; + payload: { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; + }; +} + +/** + * Message from worker to renderer to set component props. + */ +export interface SetPropsMessage extends WorkerMessage { + type: WorkerMessageType.SET_PROPS; + payload: { + componentId: string; + props: Record; + }; +} + +/** + * Message from worker to renderer requesting prop values. + */ +export interface GetPropsRequestMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_REQUEST; + payload: { + componentId: string; + properties: string[]; + }; +} + +/** + * Message from renderer to worker with prop values. + */ +export interface GetPropsResponseMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_RESPONSE; + payload: Record; +} + +/** + * Error message from worker to renderer. + */ +export interface ErrorMessage extends WorkerMessage { + type: WorkerMessageType.ERROR; + payload: { + message: string; + code?: string; + }; +} + +/** + * Connected confirmation message from worker to renderer. + */ +export interface ConnectedMessage extends WorkerMessage { + type: WorkerMessageType.CONNECTED; +} + +/** + * Disconnected notification message from worker to renderer. + */ +export interface DisconnectedMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECTED; + payload?: { + reason?: string; + }; +} + +/** + * Union type of all possible worker messages. + */ +export type AnyWorkerMessage = + | ConnectMessage + | DisconnectMessage + | CallbackRequestMessage + | CallbackResponseMessage + | SetPropsMessage + | GetPropsRequestMessage + | GetPropsResponseMessage + | ErrorMessage + | ConnectedMessage + | DisconnectedMessage; diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts new file mode 100644 index 0000000000..0e68f0b09a --- /dev/null +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -0,0 +1,135 @@ +/** + * Dash WebSocket Worker + * + * A SharedWorker that maintains a single WebSocket connection to the Dash server + * and routes messages between multiple renderer instances (browser tabs). + */ + +import { WebSocketManager } from './WebSocketManager'; +import { MessageRouter } from './MessageRouter'; +import { + WorkerMessageType, + WorkerMessage, + ConnectMessage +} from './types'; + +// SharedWorker global scope +declare const self: SharedWorkerGlobalScope; + +/** WebSocket connection manager */ +const wsManager = new WebSocketManager(); + +/** Message router for renderers */ +const router = new MessageRouter(); + +/** Current server URL */ +let serverUrl: string | null = null; + +/** + * Set up WebSocket manager callbacks. + */ +wsManager.onOpen = () => { + console.log('[DashWSWorker] WebSocket connected'); + // Notify all renderers that connection is established + for (const rendererId of getRendererIds()) { + router.notifyConnected(rendererId); + } +}; + +wsManager.onClose = (reason?: string) => { + console.log(`[DashWSWorker] WebSocket closed: ${reason}`); + router.notifyDisconnected(reason); +}; + +wsManager.onMessage = (data: unknown) => { + router.handleServerMessage(data); +}; + +wsManager.onError = (error: Error) => { + console.error('[DashWSWorker] WebSocket error:', error.message); +}; + +/** + * Set up router to send messages to WebSocket. + */ +router.sendToServer = (message: unknown) => { + wsManager.send(message); +}; + +// Track renderer IDs separately for iteration +const rendererIds = new Set(); + +/** + * Get all registered renderer IDs. + */ +function getRendererIds(): string[] { + return Array.from(rendererIds); +} + +/** + * Handle new connection from a renderer (browser tab). + */ +self.onconnect = (event: MessageEvent) => { + const port = event.ports[0]; + + port.onmessage = (e: MessageEvent) => { + const message = e.data as WorkerMessage; + + switch (message.type) { + case WorkerMessageType.CONNECT: { + const connectMsg = message as ConnectMessage; + const rendererId = connectMsg.rendererId; + const newServerUrl = connectMsg.payload.serverUrl; + const inactivityTimeout = connectMsg.payload.inactivityTimeout; + + // Register the renderer + router.registerRenderer(rendererId, port); + rendererIds.add(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}`); + + // Update inactivity timeout if provided + if (typeof inactivityTimeout === 'number') { + wsManager.setConfig({ inactivityTimeout }); + } + + // Connect to server if not already connected + if (!wsManager.isConnected) { + if (serverUrl !== newServerUrl) { + serverUrl = newServerUrl; + } + wsManager.connect(serverUrl); + } else { + // Already connected, notify the renderer + router.notifyConnected(rendererId); + } + break; + } + + case WorkerMessageType.DISCONNECT: { + const rendererId = message.rendererId; + router.unregisterRenderer(rendererId); + rendererIds.delete(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} disconnected`); + + // If no more renderers, disconnect from server + if (router.rendererCount === 0) { + wsManager.disconnect(); + serverUrl = null; + console.log('[DashWSWorker] All renderers disconnected, closing WebSocket'); + } + break; + } + + default: + // Forward other messages through the router + router.handleRendererMessage(message.rendererId, message); + } + }; + + port.start(); +}; + +// Log worker startup +console.log('[DashWSWorker] SharedWorker initialized'); diff --git a/@plotly/dash-websocket-worker/tsconfig.json b/@plotly/dash-websocket-worker/tsconfig.json new file mode 100644 index 0000000000..0254db7f91 --- /dev/null +++ b/@plotly/dash-websocket-worker/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "ESNext", + "lib": ["ES2020", "WebWorker"], + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "moduleResolution": "node", + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] +} diff --git a/@plotly/dash-websocket-worker/webpack.config.js b/@plotly/dash-websocket-worker/webpack.config.js new file mode 100644 index 0000000000..efe7b59e89 --- /dev/null +++ b/@plotly/dash-websocket-worker/webpack.config.js @@ -0,0 +1,25 @@ +const path = require('path'); + +// This config is for standalone development/testing of the worker. +// The production build is handled by dash-renderer's webpack config. +module.exports = { + entry: './src/worker.ts', + output: { + filename: 'dash-ws-worker.js', + path: path.resolve(__dirname, 'dist'), + clean: true + }, + resolve: { + extensions: ['.ts', '.js'] + }, + module: { + rules: [ + { + test: /\.ts$/, + use: 'ts-loader', + exclude: /node_modules/ + } + ] + }, + target: 'webworker' +}; diff --git a/CHANGELOG.md b/CHANGELOG.md index be8d80eaf3..b719972684 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,18 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3723](https://github.com/plotly/dash/pull/3723) Fix misaligned `dcc.Slider` marks when some labels are empty strings - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) + +## [4.2.0rc1] - 2026-05-01 + +## Fixed +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the error when using `set_props()` to update component-type properties in the `websocket` callback. +- Add threadpool for running websocket callbacks. + +## [4.2.0rc1] - 2026-04-13 + +## Added +- [#3742](https://github.com/plotly/dash/pull/3742) Add websocket callbacks to fastapi and quart backends. ## [4.1.0] - 2026-03-23 diff --git a/dash/_callback.py b/dash/_callback.py index 37a53d7ec5..718a016d82 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -77,6 +77,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + websocket: Optional[bool] = False, **_kwargs, ) -> Callable[..., Any]: """ @@ -228,6 +229,7 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + websocket=websocket, ) @@ -275,6 +277,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, + websocket=False, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -300,6 +303,7 @@ def insert_callback( "no_output": no_output, "optional": optional, "hidden": hidden, + "websocket": websocket, } if running: callback_spec["running"] = running @@ -315,6 +319,7 @@ def insert_callback( "manager": manager, "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, + "websocket": websocket, } callback_list.append(callback_spec) @@ -652,6 +657,7 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + websocket=_kwargs.get("websocket", False), ) # pylint: disable=too-many-locals diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 646db990ab..e03f343129 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,9 +1,12 @@ +import asyncio import functools import warnings import json import contextvars import typing +from dash.backends.ws import DashWebsocketCallback + from . import exceptions from ._get_app import get_app from ._utils import AttributeDict, stringify_id @@ -323,6 +326,32 @@ def custom_data(self): """ return _get_from_context("custom_data", {}) + @property + @has_context + def get_websocket(self) -> typing.Optional[DashWebsocketCallback]: + """Get WebSocket interface if running in WebSocket context. + + Returns the DashWebsocketCallback instance if the callback is being + executed via WebSocket, otherwise returns None. + + Raises: + RuntimeError: If websocket_callbacks is requested but the backend + doesn't support WebSocket. + """ + ws = _get_from_context("dash_websocket", None) + if ws is None: + app = get_app() + if ( + hasattr(app, "_websocket_callbacks") + and app._websocket_callbacks # pylint: disable=protected-access + and not app.backend.websocket_capability + ): + raise RuntimeError( + f"WebSocket callbacks requested but backend " + f"'{app.backend.server_type}' doesn't support them." + ) + return ws + callback_context = CallbackContext() @@ -330,5 +359,26 @@ def custom_data(self): def set_props(component_id: typing.Union[str, dict], props: dict): """ Set the props for a component not included in the callback outputs. + + If running in a WebSocket context, props are streamed immediately to the + client. Otherwise, props are batched and sent with the callback response. """ - callback_context.set_props(component_id, props) + ws = _get_from_context("dash_websocket", None) + if ws is not None: + # Stream immediately via WebSocket + _id = stringify_id(component_id) + + async def _send_props(): + for prop_name, value in props.items(): + await ws.set_prop(_id, prop_name, value) + + # If we're in an async context, schedule the coroutine + try: + asyncio.get_running_loop() + asyncio.ensure_future(_send_props()) + except RuntimeError: + # No running event loop - run synchronously + asyncio.run(_send_props()) + else: + # Batch for response (existing behavior) + callback_context.set_props(component_id, props) diff --git a/dash/_dash_renderer.py b/dash/_dash_renderer.py index ee507ddb71..5574131d10 100644 --- a/dash/_dash_renderer.py +++ b/dash/_dash_renderer.py @@ -1,7 +1,7 @@ import os from typing import Any, List, Dict -__version__ = "3.0.0" +__version__ = "3.1.0" _available_react_versions = {"18.3.1", "18.2.0", "16.14.0"} _available_reactdom_versions = {"18.3.1", "18.2.0", "16.14.0"} @@ -65,7 +65,7 @@ def _set_react_version(v_react, v_reactdom=None): { "relative_package_path": "dash-renderer/build/dash_renderer.min.js", "dev_package_path": "dash-renderer/build/dash_renderer.dev.js", - "external_url": "https://unpkg.com/dash-renderer@3.0.0" + "external_url": "https://unpkg.com/dash-renderer@3.1.0" "/build/dash_renderer.min.js", "namespace": "dash", }, @@ -75,4 +75,9 @@ def _set_react_version(v_react, v_reactdom=None): "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/_hooks.py b/dash/_hooks.py index 1631b40ddc..f260b1fcb0 100644 --- a/dash/_hooks.py +++ b/dash/_hooks.py @@ -49,6 +49,8 @@ def __init__(self) -> None: "index": [], "custom_data": [], "dev_tools": [], + "websocket_connect": [], + "websocket_message": [], } self._js_dist: _t.List[_t.Any] = [] self._css_dist: _t.List[_t.Any] = [] @@ -244,6 +246,60 @@ def devtool( } ) + def websocket_connect(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket connection validation hook. + + The hook receives the WebSocket object and should return: + - True (or any truthy value): Allow the connection + - False: Reject with default code (4001) and reason + - tuple (code, reason): Reject with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_connect() + async def validate_session(websocket): + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_connect", func, priority=priority, final=final) + return func + + return decorator + + def websocket_message(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket message validation hook. + + The hook receives the WebSocket object and message dict, and should return: + - True (or any truthy value): Allow the message + - False: Disconnect with default code (4001) and reason + - tuple (code, reason): Disconnect with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_message() + async def validate_session(websocket, message): + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_message", func, priority=priority, final=final) + return func + + return decorator + hooks = _Hooks() diff --git a/dash/_validate.py b/dash/_validate.py index fb5689f850..b80c61df2c 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -629,3 +629,33 @@ def check_backend(backend, inferred_backend): raise ValueError( f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) + + +def validate_websocket_callback_request( + callback_id, callback_map, websocket_callbacks_enabled +): + """Validate a WebSocket callback request at runtime. + + Called by WebSocket handlers to verify that a callback received via WebSocket + is actually allowed to use WebSocket transport. + + Args: + callback_id: The callback output ID from the request + callback_map: The app's callback_map dictionary + websocket_callbacks_enabled: Whether websocket_callbacks=True at app level + + Raises: + WebSocketCallbackError: If the callback is not websocket-enabled + """ + # If global websocket_callbacks is enabled, all callbacks can use WebSocket + if websocket_callbacks_enabled: + return + + # Otherwise, check if this specific callback has websocket=True + cb = callback_map.get(callback_id, {}) + if not cb.get("websocket"): + raise exceptions.WebSocketCallbackError( + f"Callback '{callback_id}' received via WebSocket but does not have " + f"websocket=True. Either enable websocket_callbacks=True globally " + f"or add websocket=True to this callback." + ) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4e5bdad621..1fc60cd703 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,7 +1,10 @@ from __future__ import annotations from contextvars import copy_context, ContextVar +import asyncio +import concurrent.futures import json +import queue from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -13,6 +16,7 @@ import subprocess import threading import traceback +from urllib.parse import urlparse try: from fastapi import FastAPI, Request, Response, Body @@ -21,16 +25,30 @@ from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders from starlette.types import ASGIApp, Scope, Receive, Send + from starlette.websockets import WebSocket, WebSocketDisconnect import uvicorn except ImportError as _err: raise ImportError( "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." ) from _err +import janus + from dash.fingerprint import check_fingerprint from dash import _validate, get_app from dash.exceptions import PreventUpdate -from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, +) +from .ws import ( + DashWebsocketCallback, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, +) from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -224,6 +242,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class FastAPIDashServer(BaseDashServer[FastAPI]): + websocket_capability: bool = True + def __init__(self, server: FastAPI): super().__init__(server) self.server_type = "fastapi" @@ -354,11 +374,14 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R is_threaded = threading.current_thread() != threading.main_thread() if is_threaded: - # Running in a thread (testing context) - use uvicorn.run directly - # This allows the testing framework to control the server lifecycle - if kwargs.get("reload"): - kwargs["reload"] = True - uvicorn.run(self.server, host=host, port=port, **kwargs) + # Running in a thread (testing context) - use uvicorn.Server + # This allows graceful shutdown via should_exit flag + kwargs.pop("reload", None) # Reload not supported in threaded mode + config = uvicorn.Config(self.server, host=host, port=port, **kwargs) + server = uvicorn.Server(config) + # Store server reference on the app for graceful shutdown + dash_app._uvicorn_server = server # pylint: disable=protected-access + server.run() else: # Running in main thread (normal context) - use subprocess file_path = frame.filename @@ -609,6 +632,184 @@ async def timing_headers_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response + async def _run_ws_hooks( + self, hooks, websocket: "WebSocket", *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + websocket: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(websocket, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements,too-many-locals + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + + # Get allowed origins from dash app config + allowed_origins = getattr( + dash_app, "_websocket_allowed_origins", [] + ) # pylint: disable=protected-access + + def validate_origin(origin: str | None, host: str | None) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + async def websocket_handler(websocket: WebSocket): + # Validate Origin header to prevent Cross-Site WebSocket Hijacking + origin = websocket.headers.get("origin") + host = websocket.headers.get("host") + error = validate_origin(origin, host) + if error: + await websocket.close(code=4003, reason=error) + return + + # Call websocket_connect hooks (before accept) + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + websocket, + default_reason="Connection rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + await websocket.accept() + + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task( + run_ws_sender(websocket.send_text, outbound_queue) + ) + + try: + while True: + message = await websocket.receive_json() + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + websocket, + message, + default_reason="Message rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + + msg_type = message.get("type") + + if msg_type == "callback_request": + request_id = message.get("requestId") + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, renderer_id, outbound_queue + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + ) + ) + pending_callbacks[request_id] = future + + elif msg_type == "get_props_response": + # Put response in waiting queue (non-blocking) + request_id = message.get("requestId") + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) + + elif msg_type == "heartbeat": + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') + + except WebSocketDisconnect: + pass # Clean disconnect + finally: + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() + + self.server.add_api_websocket_route(ws_path, websocket_handler) + class FastAPIRequestAdapter(RequestAdapter): def __init__(self): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index ddf31ff2f4..9441ba8bd3 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -6,10 +6,14 @@ import pkgutil import time import sys +import asyncio +import concurrent.futures +import queue +from urllib.parse import urlparse from logging.config import dictConfig from contextvars import copy_context -from typing import Any +from typing import Any, Dict, TYPE_CHECKING from importlib_metadata import version as _get_distribution_version @@ -24,19 +28,36 @@ g as quart_g, has_request_context, redirect, + websocket, ) except ImportError as _err: raise ImportError( "All dependencies not installed. Please install it with `dash[quart]` to use the Quart backend." ) from _err +import janus + from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint from dash._utils import parse_version -from dash import _validate, Dash -from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter +from dash import _validate +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, +) +from .ws import ( + DashWebsocketCallback, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, +) from ._utils import format_traceback_html +if TYPE_CHECKING: + from dash import Dash + class QuartResponseAdapter(ResponseAdapter): """ @@ -67,6 +88,8 @@ def set_response(self, **kwargs): class QuartDashServer(BaseDashServer[Quart]): + websocket_capability: bool = True + def __init__(self, server: Quart) -> None: super().__init__(server) self.server_type = "quart" @@ -74,6 +97,8 @@ def __init__(self, server: Quart) -> None: self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter self.response_adapter = QuartResponseAdapter + self._active_websockets: set = set() + self._ws_shutdown_event: asyncio.Event | None = None def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @@ -222,6 +247,15 @@ def has_request_context(self) -> bool: # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): + import signal # pylint: disable=import-outside-toplevel + import threading # pylint: disable=import-outside-toplevel + + # pylint: disable=import-outside-toplevel,import-error + from hypercorn.config import Config + from hypercorn.asyncio import serve + + # pylint: enable=import-error + self.config = {"debug": debug, **kwargs} if debug else kwargs # pylint: disable=protected-access if dash_app._dev_tools.silence_routes_logging: @@ -236,7 +270,51 @@ def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.An } ) - self.server.run(host=host, port=port, debug=debug, **kwargs) + # Check if we're running in a non-main thread (e.g., testing context) + is_main_thread = threading.current_thread() is threading.main_thread() + + config = Config() + config.bind = [f"{host}:{port}"] + config.use_reloader = False + if not is_main_thread: + config.accesslog = None + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Initialize shutdown event for WebSocket handlers + self._ws_shutdown_event = asyncio.Event() + + def signal_handler(): + """Handle shutdown signal by setting the WebSocket shutdown event.""" + if self._ws_shutdown_event is not None: + self._ws_shutdown_event.set() + + # Set up signal handlers in main thread + if is_main_thread: + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, signal_handler) + except (NotImplementedError, ValueError): + pass + + print(f" * Serving Quart app '{self.server.name}'") + print(f" * Debug mode: {debug}") + print( + " * Please use an ASGI server (e.g. Hypercorn) directly in production" + ) + print(f" * Running on http://{host}:{port} (CTRL + C to quit)") + + async def shutdown_trigger(): + if self._ws_shutdown_event is not None: + await self._ws_shutdown_event.wait() + + try: + loop.run_until_complete( + serve(self.server, config, shutdown_trigger=shutdown_trigger) + ) + finally: + loop.close() def make_response( self, @@ -385,6 +463,200 @@ def enable_compression(self) -> None: "To use the compress option, you need to install quart_compress." ) from error + async def _run_ws_hooks( + self, hooks, ws, *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + ws: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(ws, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def _validate_ws_origin( + self, origin: str | None, host: str | None, allowed_origins: list + ) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements,too-many-locals + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + # pylint: disable=protected-access + allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) + + @self.server.websocket(ws_path) + async def websocket_handler(): + ws = websocket + + # Validate Origin header + error = self._validate_ws_origin( + ws.headers.get("origin"), ws.headers.get("host"), allowed_origins + ) + if error: + await ws.close(code=4003, reason=error) + return + + # Call websocket_connect hooks + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + ws, + default_reason="Connection rejected", + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + + await ws.accept() + + # Track this connection for graceful shutdown + try: + ws_obj = ws._get_current_object() + self._active_websockets.add(ws_obj) + except AttributeError: + ws_obj = ws + self._active_websockets.add(ws) + + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task(run_ws_sender(ws.send, outbound_queue)) + + try: + shutdown_event = self._ws_shutdown_event + while shutdown_event is None or not shutdown_event.is_set(): + try: + # Use timeout to periodically check shutdown event + message = await asyncio.wait_for(ws.receive_json(), timeout=1.0) + except asyncio.TimeoutError: + # Re-check shutdown event (may have been set during run()) + shutdown_event = self._ws_shutdown_event + continue + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + ws, + message, + default_reason="Message rejected", + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + + msg_type = message.get("type") + + if msg_type == "callback_request": + request_id = message.get("requestId") + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, renderer_id, outbound_queue + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + ) + ) + pending_callbacks[request_id] = future + + elif msg_type == "get_props_response": + # Put response in waiting queue (non-blocking) + request_id = message.get("requestId") + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) + + elif msg_type == "heartbeat": + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') + + except asyncio.CancelledError: + pass # Server is shutting down, exit gracefully + except Exception: # pylint: disable=broad-exception-caught + pass # Other exceptions treated as disconnect + finally: + self._active_websockets.discard(ws_obj) + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() + class QuartRequestAdapter(RequestAdapter): def __init__(self) -> None: diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index f7211f44a6..52443d4104 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -3,9 +3,19 @@ This module provides abstract base classes and protocols that define the interface for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. """ -from abc import ABC, abstractmethod -from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING +from __future__ import annotations +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import ( + Any, + Dict, + Type, + TypeVar, + Generic, + Protocol, + TYPE_CHECKING, +) if TYPE_CHECKING: import dash @@ -169,6 +179,7 @@ class BaseDashServer(ABC, Generic[ServerType]): config: Dict[str, Any] request_adapter: Type[RequestAdapter] response_adapter: Type[ResponseAdapter] + websocket_capability: bool = False def __init__(self, server: ServerType) -> None: """Initialize the server wrapper. @@ -178,6 +189,34 @@ def __init__(self, server: ServerType) -> None: """ super().__init__() self.server = server + self._callback_executor: ThreadPoolExecutor | None = None + + def get_callback_executor( + self, max_workers: int | None = None + ) -> ThreadPoolExecutor: + """Get or create the thread pool executor for callback execution. + + Args: + max_workers: Maximum number of worker threads. If None, uses default. + + Returns: + ThreadPoolExecutor instance for running callbacks. + """ + if self._callback_executor is None: + self._callback_executor = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) + return self._callback_executor + + def shutdown_executor(self, wait: bool = True) -> None: + """Shutdown the callback executor. + + Args: + wait: If True, wait for pending tasks to complete. + """ + if self._callback_executor is not None: + self._callback_executor.shutdown(wait=wait) + self._callback_executor = None def __call__(self, *args, **kwargs) -> Any: """Make the server wrapper callable as a WSGI/ASGI application. @@ -372,3 +411,12 @@ def setup_backend(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ + + def serve_websocket_callback(self, dash_app: "dash.Dash"): + """Set up the WebSocket endpoint for callback handling. + + Override this method in backends that support WebSocket callbacks. + + Args: + dash_app: The Dash application instance + """ diff --git a/dash/backends/ws.py b/dash/backends/ws.py new file mode 100644 index 0000000000..db59fa1628 --- /dev/null +++ b/dash/backends/ws.py @@ -0,0 +1,303 @@ +"""WebSocket callback support for Dash backend implementations. + +This module provides the WebSocket callback infrastructure for real-time +bidirectional communication between Dash backends and the renderer. +""" +from __future__ import annotations + +import asyncio +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +import inspect +import json +import queue +import traceback +import uuid +from contextvars import copy_context +from typing import Any, Callable, Dict, TYPE_CHECKING, cast + +import janus + +from dash.exceptions import PreventUpdate +from dash._utils import to_json + +if TYPE_CHECKING: + import dash + from .base_server import ResponseAdapter + + +SHUTDOWN_SIGNAL = "__shutdown__" + + +class DashWebsocketCallback: + """WebSocket callback communication via queues. + + Provides methods for real-time bidirectional communication between + the server and renderer during callback execution. + + Uses janus.Queue for outbound messages (serialized with to_json) and + queue.Queue for get_props responses, enabling thread-safe communication + between worker threads and the main event loop. + """ + + def __init__( + self, + pending_get_props: Dict[str, queue.Queue[Any]], + renderer_id: str, + outbound_queue: janus.Queue[str], + ): + """Initialize the WebSocket callback interface. + + Args: + pending_get_props: Dict to track pending get_props requests. + Values are queue.Queue instances for blocking response retrieval. + renderer_id: The renderer ID for routing messages back to the correct client + outbound_queue: janus.Queue for thread-safe outbound messaging. + """ + self._pending_get_props = pending_get_props + self._renderer_id = renderer_id + self._outbound_queue = outbound_queue + + def _queue_message(self, msg: dict) -> None: + """Serialize and queue message for sending (thread-safe, non-blocking). + + Uses to_json for proper serialization of Dash components. + """ + self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) + + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Queues the message for the sender coroutine to send. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + msg = { + "type": "set_props", + "rendererId": self._renderer_id, + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + self._queue_message(msg) + + async def get_prop( + self, component_id: str, prop_name: str, timeout: float = 30.0 + ) -> Any: + """Request current prop value from the client. + + Uses queue.Queue for blocking wait in worker thread. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + timeout: Timeout in seconds for waiting for response + + Returns: + The current value of the property from the client's state + """ + request_id = str(uuid.uuid4()) + msg = { + "type": "get_props_request", + "rendererId": self._renderer_id, + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + + # Use standard queue.Queue for response + response_queue: queue.Queue = queue.Queue() + self._pending_get_props[request_id] = response_queue + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + # Wait for response (blocking is OK in worker thread) + try: + result = response_queue.get(timeout=timeout) + if result and prop_name in result: + return result[prop_name] + return None + except queue.Empty as exc: + raise TimeoutError( + f"Timeout waiting for {component_id}.{prop_name}" + ) from exc + finally: + self._pending_get_props.pop(request_id, None) + + +def create_ws_context( + payload: dict, + response_adapter: "ResponseAdapter", + websocket_callback: DashWebsocketCallback, +): + """Create callback context from WebSocket message. + + Args: + payload: The callback payload + response_adapter: The response adapter instance for the backend + websocket_callback: The websocket callback instance for the backend + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = response_adapter + g.updated_props = {} + g.dash_websocket = websocket_callback + + return g + + +async def run_ws_sender( + send_text: Callable[[str], Any], outbound_queue: janus.Queue[str] +) -> None: + """Sender coroutine - drains queue and sends to WebSocket. + + This coroutine runs in the main event loop and handles sending + messages that are queued by worker threads via janus.Queue. + + Messages are pre-serialized strings (using to_json). + + Args: + send_text: Async function to send text data over WebSocket + outbound_queue: janus.Queue instance for receiving messages (strings) + """ + try: + while True: + msg = await outbound_queue.async_q.get() + if msg == SHUTDOWN_SIGNAL: + break + await send_text(msg) + except asyncio.CancelledError: + pass + + +def make_callback_done_handler( + outbound_queue: janus.Queue[str], + pending_callbacks: Dict[str, concurrent.futures.Future], + request_id: str, + renderer_id: str, +) -> Callable[[concurrent.futures.Future], None]: + """Create a done callback handler for executor futures. + + This factory creates a callback that sends the result back through + the WebSocket when an executor future completes. + + Args: + outbound_queue: janus.Queue for sending responses + pending_callbacks: Dict tracking pending callbacks for cleanup + request_id: The request ID for the callback response + renderer_id: The renderer ID for routing the response + + Returns: + A callback function suitable for Future.add_done_callback() + """ + + def on_done(f: concurrent.futures.Future) -> None: + try: + result = f.result() + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) + ) + finally: + pending_callbacks.pop(request_id, None) + + return on_done + + +def run_callback_in_executor( + executor: ThreadPoolExecutor, + dash_app: "dash.Dash", + payload: dict, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> concurrent.futures.Future: + """Submit callback to executor for thread pool execution. + + This function creates a callback execution context and runs it + in a separate thread. Both sync and async callbacks are supported. + + Args: + executor: ThreadPoolExecutor to submit the task to + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + Future representing the pending callback execution + """ + + def execute() -> dict: + try: + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + ) + + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result + + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + return executor.submit(execute) diff --git a/dash/dash-renderer/init.template b/dash/dash-renderer/init.template index 463cfa02aa..a6b84d3d70 100644 --- a/dash/dash-renderer/init.template +++ b/dash/dash-renderer/init.template @@ -75,4 +75,9 @@ _js_dist = [ "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index f92d22cfc5..a404fa2425 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -13,7 +13,7 @@ "build:dev": "webpack", "build:local": "renderer build local", "build": "renderer build && npm run prepublishOnly", - "postbuild": "es-check es2015 ../deps/*.js build/*.js", + "postbuild": "es-check es2015 ../deps/*.js build/dash_renderer.*.js", "test": "karma start karma.conf.js --single-run", "format": "run-s private::format.*", "lint": "run-s private::lint.* --continue-on-error" diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 343789ca43..f9d8b06f1a 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -1,9 +1,14 @@ import PropTypes from 'prop-types'; -import React, {useState} from 'react'; +import React, {useState, useEffect} from 'react'; import {Provider} from 'react-redux'; import Store from './store'; import AppContainer from './AppContainer.react'; +import getConfigFromDOM from './config'; +import { + initializeWebSocket, + disconnectWebSocket +} from './observers/websocketObserver'; const AppProvider = ({ hooks = { @@ -16,6 +21,35 @@ const AppProvider = ({ } }: any) => { const [{store}] = useState(() => new Store()); + + // Initialize WebSocket connection if enabled or if websocket config is available + // (for per-callback websocket=True) + useEffect(() => { + const config = getConfigFromDOM(); + if ( + config.websocket?.enabled || + (config.websocket?.url && config.websocket?.worker_url) + ) { + // Add fetch config for consistency + const fullConfig = { + ...config, + fetch: { + credentials: 'same-origin', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json' + } + } + }; + initializeWebSocket(store, fullConfig); + } + + // Cleanup on unmount + return () => { + disconnectWebSocket(); + }; + }, [store]); + return ( diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 206839cd07..e6604a2337 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -52,6 +52,11 @@ import {parsePMCId} from './patternMatching'; import {replacePMC} from './patternMatching'; import {loaded, loading} from './loading'; import {getComponentLayout} from '../wrapper/wrapping'; +import { + getWorkerClient, + isWebSocketEnabled, + isWebSocketAvailable +} from '../utils/workerClient'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -685,6 +690,140 @@ function handleServerside( }); } +/** + * Handle serverside callback via WebSocket connection. + * + * Uses the SharedWorker to send the callback request through the persistent + * WebSocket connection instead of HTTP POST. + */ +async function handleWebsocketCallback( + dispatch: any, + hooks: any, + config: any, + payload: ICallbackPayload, + running: any +): Promise { + if (hooks.request_pre) { + hooks.request_pre(payload); + } + + const requestTime = Date.now(); + let runningOff: any; + + if (running) { + dispatch(sideUpdate(running.running, payload)); + runningOff = running.runningOff; + } + + const workerClient = getWorkerClient(); + + try { + // Ensure WebSocket connection is established + await workerClient.ensureConnected(config); + + const response = await workerClient.sendCallback(payload); + + // Handle running off state + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (response.status === 'prevent_update') { + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.PREVENT_UPDATE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + return {}; + } + + if (response.status === 'error') { + throw new Error(response.message || 'Callback error'); + } + + // Extract the callback data - structure is {multi: boolean, response: {...}} + const callbackData = response.data as CallbackResponseData; + + // Handle sideUpdate if present + if (callbackData?.sideUpdate) { + dispatch(sideUpdate(callbackData.sideUpdate, payload)); + } + + // Extract the actual outputs from the response + // Format is similar to HTTP path's finishLine function + let result: CallbackResponse; + const {multi, response: callbackResponse} = callbackData || {}; + + if (hooks.request_post) { + hooks.request_post(payload, callbackResponse); + } + + if (multi) { + result = callbackResponse as CallbackResponse; + } else { + // Single output - convert to the expected format + const {output} = payload; + const id = output.substr(0, output.lastIndexOf('.')); + result = {[id]: (callbackResponse as CallbackResponse)?.props}; + } + + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.OK, + result: result || {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + return result || {}; + } catch (error) { + // Handle running off state on error + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (config.ui) { + dispatch( + updateResourceUsage({ + id: payload.output, + status: STATUS.NO_RESPONSE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + throw error; + } +} + function inputsToDict(inputs_list: any) { // Ported directly from _utils.py, inputs_to_dict // takes an array of inputs (some inputs may be an array) @@ -890,18 +1029,44 @@ export function executeCallback( } ); + // Use WebSocket for callbacks when: + // 1. Global WebSocket is enabled, OR + // 2. Per-callback websocket flag is set (and WebSocket is available) + // (but never for background callbacks) + const useWebSocket = + !background && + (isWebSocketEnabled(config) || + (cb.callback.websocket && + isWebSocketAvailable(config))); + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { - let data = await handleServerside( - dispatch, - hooks, - newConfig, - payload, - background, - additionalArgs.length ? additionalArgs : undefined, - getState, - cb.callback.running - ); + let data: CallbackResponse; + + if (useWebSocket) { + // Use WebSocket path for real-time callbacks + data = await handleWebsocketCallback( + dispatch, + hooks, + newConfig, + payload, + cb.callback.running + ); + } else { + // Use traditional HTTP path + data = await handleServerside( + dispatch, + hooks, + newConfig, + payload, + background, + additionalArgs.length + ? additionalArgs + : undefined, + getState, + cb.callback.running + ); + } if (newHeaders) { dispatch(addHttpHeaders(newHeaders)); diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index d7f16beda8..6eb9e27b58 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -22,6 +22,12 @@ export type DashConfig = { serve_locally?: boolean; plotlyjs_url?: string; validate_callbacks: boolean; + websocket?: { + enabled: boolean; + url: string; + worker_url: string; + inactivity_timeout?: number; + }; csrf_token_name?: string; csrf_header_name?: string; }; diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts new file mode 100644 index 0000000000..7b75fada38 --- /dev/null +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -0,0 +1,215 @@ +/** + * Observer for handling incoming WebSocket messages (SET_PROPS, GET_PROPS_REQUEST). + */ + +/* eslint-disable no-console */ + +import {Store} from 'redux'; +import {path} from 'ramda'; + +import {IStoreState} from '../store'; +import {updateProps, notifyObservers, setPaths} from '../actions'; +import {parsePatchProps} from '../actions/patch'; +import {computePaths, getPath} from '../actions/paths'; +import { + getWorkerClient, + SetPropsPayload, + GetPropsRequestPayload +} from '../utils/workerClient'; +import {DashConfig} from '../config'; + +/** + * Parse a component ID that may be a stringified JSON object. + * This handles dict IDs like '{"index":0,"type":"output"}' that need + * to be parsed back to objects for getPath to work correctly. + */ +function parseComponentId( + componentId: string +): string | Record { + if (componentId.startsWith('{') && componentId.endsWith('}')) { + try { + return JSON.parse(componentId); + } catch { + // Not valid JSON, return as-is + return componentId; + } + } + return componentId; +} + +/** + * Initialize the WebSocket observer. + * + * Sets up handlers for: + * - SET_PROPS: Update component props when received from server + * - GET_PROPS_REQUEST: Send current prop values back to server + * + * @param store Redux store + * @param config Dash configuration + */ +export async function initializeWebSocket( + store: Store, + config: DashConfig +): Promise { + // Initialize WebSocket if: + // 1. Global websocket is enabled, OR + // 2. WebSocket config is available (for per-callback websocket=True) + const wsAvailable = !!( + config.websocket?.url && config.websocket?.worker_url + ); + if (!wsAvailable) { + return; + } + + // Check if SharedWorker is supported + if (typeof SharedWorker === 'undefined') { + console.warn( + 'SharedWorker not supported in this browser. ' + + 'WebSocket callbacks will fall back to HTTP.' + ); + return; + } + + const workerClient = getWorkerClient(); + + // Handle SET_PROPS messages + workerClient.onSetProps = (payload: SetPropsPayload) => { + const {componentId, props: rawProps} = payload; + const parsedId = parseComponentId(componentId); + const state = store.getState(); + const componentPath = getPath(state.paths, parsedId); + + if (!componentPath) { + console.warn( + `SET_PROPS: Component ${componentId} not found in layout` + ); + return; + } + + // Get old component for Patch processing and path recomputation + const oldComponent = path(componentPath, state.layout) as Record< + string, + unknown + > | null; + const oldProps = (oldComponent?.props || {}) as Record; + + // Process props to handle Patch objects + const processedProps = parsePatchProps(rawProps, oldProps); + + // Update the component props + store.dispatch( + updateProps({ + props: processedProps, + itempath: componentPath, + renderType: 'websocket' + }) as any + ); + + // Notify observers + store.dispatch( + notifyObservers({id: parsedId, props: processedProps}) as any + ); + + // Recompute paths for any new child components + if (oldComponent) { + const updatedState = store.getState(); + store.dispatch( + setPaths( + computePaths( + { + ...oldComponent, + props: {...oldProps, ...processedProps} + }, + [...componentPath], + updatedState.paths, + updatedState.paths.events + ) + ) as any + ); + } + }; + + // Handle GET_PROPS_REQUEST messages + workerClient.onGetPropsRequest = ( + requestId: string, + payload: GetPropsRequestPayload + ) => { + const {componentId, properties} = payload; + const parsedId = parseComponentId(componentId); + const state = store.getState(); + const componentPath = getPath(state.paths, parsedId); + + const result: Record = {}; + + if (componentPath) { + const componentProps = path( + [...componentPath, 'props'], + state.layout + ) as Record | undefined; + + if (componentProps) { + for (const propName of properties) { + result[propName] = componentProps[propName]; + } + } + } else { + console.warn( + `GET_PROPS_REQUEST: Component ${componentId} not found in layout` + ); + } + + // Send the response + workerClient.sendGetPropsResponse(requestId, result); + }; + + // Handle connection events + workerClient.onConnected = () => { + console.log('[Dash] WebSocket connected'); + }; + + workerClient.onDisconnected = (reason?: string) => { + console.log(`[Dash] WebSocket disconnected: ${reason}`); + }; + + workerClient.onError = (message: string, code?: string) => { + console.error(`[Dash] WebSocket error: ${message}`, code); + }; + + // Connect to the worker + const wsUrl = buildWebSocketUrl(config); + + try { + // config.websocket is guaranteed to exist due to wsAvailable check above + await workerClient.connect( + config.websocket!.worker_url, + wsUrl, + config.websocket!.inactivity_timeout + ); + } catch (error) { + console.error('[Dash] Failed to connect to WebSocket worker:', error); + } +} + +/** + * Build the WebSocket URL from config. + */ +function buildWebSocketUrl(config: DashConfig): string { + if (!config.websocket?.url) { + throw new Error('WebSocket URL not configured'); + } + + // Convert HTTP(S) URL to WS(S) + const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + + // The config.websocket.url is a path like "/_dash-ws-callback" + return `${wsProtocol}//${host}${config.websocket.url}`; +} + +/** + * Disconnect from the WebSocket. + */ +export function disconnectWebSocket(): void { + const workerClient = getWorkerClient(); + workerClient.disconnect(); +} diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index f1e1dc382c..38a5d7d82f 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -15,6 +15,7 @@ export interface ICallbackDefinition { dynamic_creator?: boolean; running: any; no_output?: boolean; + websocket?: boolean; } export interface ICallbackProperty { diff --git a/dash/dash-renderer/src/utils/rendererId.ts b/dash/dash-renderer/src/utils/rendererId.ts new file mode 100644 index 0000000000..b9bfcfd3af --- /dev/null +++ b/dash/dash-renderer/src/utils/rendererId.ts @@ -0,0 +1,22 @@ +/** Cached renderer ID for this page instance */ +let cachedRendererId: string | null = null; + +/** + * Generate a unique renderer ID for this page instance. + * + * Each page load gets a fresh ID to avoid conflicts with stale + * connections in the SharedWorker after page reloads. + */ +export function getRendererId(): string { + if (!cachedRendererId) { + if (typeof crypto !== 'undefined' && crypto.randomUUID) { + cachedRendererId = crypto.randomUUID(); + } else { + // Fallback for older browsers + cachedRendererId = `${Date.now()}-${Math.random() + .toString(36) + .slice(2)}`; + } + } + return cachedRendererId; +} diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts new file mode 100644 index 0000000000..f7cf4d613b --- /dev/null +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -0,0 +1,350 @@ +/** + * Client for communicating with the Dash WebSocket SharedWorker. + */ + +import {getRendererId} from './rendererId'; + +/** Message types for worker communication */ +export enum WorkerMessageType { + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** Callback response structure */ +export interface CallbackResponse { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; +} + +/** Set props message payload */ +export interface SetPropsPayload { + componentId: string; + props: Record; +} + +/** Get props request payload */ +export interface GetPropsRequestPayload { + componentId: string; + properties: string[]; +} + +/** Pending callback request */ +interface PendingRequest { + resolve: (value: CallbackResponse) => void; + reject: (error: Error) => void; +} + +/** + * Client for the Dash WebSocket SharedWorker. + */ +class WorkerClient { + private worker: SharedWorker | null = null; + private rendererId: string; + private pendingCallbacks: Map = new Map(); + private requestCounter = 0; + private isConnected = false; + private connectionPromise: Promise | null = null; + private connectionResolve: (() => void) | null = null; + + /** Callback when SET_PROPS message is received */ + public onSetProps: ((payload: SetPropsPayload) => void) | null = null; + + /** Callback when GET_PROPS_REQUEST message is received */ + public onGetPropsRequest: + | ((requestId: string, payload: GetPropsRequestPayload) => void) + | null = null; + + /** Callback when connection is established */ + public onConnected: (() => void) | null = null; + + /** Callback when connection is lost */ + public onDisconnected: ((reason?: string) => void) | null = null; + + /** Callback when an error occurs */ + public onError: ((message: string, code?: string) => void) | null = null; + + constructor() { + this.rendererId = getRendererId(); + } + + /** + * Initialize the worker connection. + * @param workerUrl URL to the SharedWorker script + * @param serverUrl WebSocket server URL + * @param inactivityTimeout Optional inactivity timeout in ms + */ + public async connect( + workerUrl: string, + serverUrl: string, + inactivityTimeout?: number + ): Promise { + if (this.worker) { + // Already connected + return; + } + + // Create the SharedWorker + this.worker = new SharedWorker(workerUrl, { + name: 'dash-ws-worker' + }); + + // Set up message handling + this.worker.port.onmessage = this.handleMessage.bind(this); + + // Create promise for connection + this.connectionPromise = new Promise(resolve => { + this.connectionResolve = resolve; + }); + + // Start the port + this.worker.port.start(); + + // Send connect message + this.worker.port.postMessage({ + type: WorkerMessageType.CONNECT, + rendererId: this.rendererId, + payload: { + serverUrl, + inactivityTimeout + } + }); + + // Wait for connection + await this.connectionPromise; + } + + /** + * Disconnect from the worker. + */ + public disconnect(): void { + if (this.worker) { + this.worker.port.postMessage({ + type: WorkerMessageType.DISCONNECT, + rendererId: this.rendererId + }); + this.worker.port.close(); + this.worker = null; + } + this.isConnected = false; + this.connectionPromise = null; + this.connectionResolve = null; + + // Reject any pending callbacks + for (const [, pending] of this.pendingCallbacks) { + pending.reject(new Error('Worker disconnected')); + } + this.pendingCallbacks.clear(); + } + + /** + * Ensure the worker is connected, initiating connection if needed. + * @param config The Dash config with websocket settings + */ + public async ensureConnected(config: { + websocket?: { + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; + }): Promise { + // Already connected + if (this.isConnected) { + return; + } + + // Connection in progress, wait for it + if (this.connectionPromise) { + await this.connectionPromise; + return; + } + + // Need to initiate connection + if (!config.websocket?.url || !config.websocket?.worker_url) { + throw new Error('WebSocket config not available'); + } + + if (typeof SharedWorker === 'undefined') { + throw new Error('SharedWorker not supported'); + } + + // Build WebSocket URL + const wsProtocol = + window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + const wsUrl = `${wsProtocol}//${host}${config.websocket.url}`; + + await this.connect( + config.websocket.worker_url, + wsUrl, + config.websocket.inactivity_timeout + ); + } + + /** + * Send a callback request to the server via the worker. + * @param payload The callback payload + * @returns Promise that resolves with the callback response + */ + public async sendCallback(payload: unknown): Promise { + // Wait for initial connection if one is in progress + if (this.connectionPromise && !this.isConnected) { + await this.connectionPromise; + } + + if (!this.worker) { + throw new Error('Worker not connected'); + } + + const requestId = `${this.rendererId}-${++this.requestCounter}`; + + return new Promise((resolve, reject) => { + this.pendingCallbacks.set(requestId, {resolve, reject}); + + this.worker!.port.postMessage({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId: this.rendererId, + requestId, + payload + }); + }); + } + + /** + * Send a get_props response back to the server. + * @param requestId The request ID from the get_props request + * @param props The property values + */ + public sendGetPropsResponse( + requestId: string, + props: Record + ): void { + if (!this.worker || !this.isConnected) { + return; + } + + this.worker.port.postMessage({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId: this.rendererId, + requestId, + payload: props + }); + } + + /** + * Check if the worker is connected. + */ + public get connected(): boolean { + return this.isConnected; + } + + private handleMessage(event: MessageEvent): void { + const message = event.data; + + switch (message.type) { + case WorkerMessageType.CONNECTED: + this.isConnected = true; + if (this.connectionResolve) { + this.connectionResolve(); + this.connectionResolve = null; + } + if (this.onConnected) { + this.onConnected(); + } + break; + + case WorkerMessageType.DISCONNECTED: + this.isConnected = false; + // Reject all pending callbacks so loading states don't stay on forever + for (const [, pending] of this.pendingCallbacks) { + pending.reject(new Error('WebSocket disconnected')); + } + this.pendingCallbacks.clear(); + if (this.onDisconnected) { + this.onDisconnected(message.payload?.reason); + } + break; + + case WorkerMessageType.CALLBACK_RESPONSE: { + const requestId = message.requestId; + const pending = this.pendingCallbacks.get(requestId); + if (pending) { + this.pendingCallbacks.delete(requestId); + pending.resolve(message.payload); + } + break; + } + + case WorkerMessageType.SET_PROPS: + if (this.onSetProps) { + this.onSetProps(message.payload); + } + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + if (this.onGetPropsRequest) { + this.onGetPropsRequest(message.requestId, message.payload); + } + break; + + case WorkerMessageType.ERROR: + if (this.onError) { + this.onError( + message.payload?.message || 'Unknown error', + message.payload?.code + ); + } + break; + } + } +} + +// Singleton instance +let workerClientInstance: WorkerClient | null = null; + +/** + * Get the singleton WorkerClient instance. + */ +export function getWorkerClient(): WorkerClient { + if (!workerClientInstance) { + workerClientInstance = new WorkerClient(); + } + return workerClientInstance; +} + +/** + * Check if WebSocket callbacks are globally enabled and supported. + * @param config The Dash config + */ +export function isWebSocketEnabled(config: { + websocket?: {enabled: boolean}; +}): boolean { + return !!(config.websocket?.enabled && typeof SharedWorker !== 'undefined'); +} + +/** + * Check if WebSocket infrastructure is available (for per-callback websocket). + * @param config The Dash config + */ +export function isWebSocketAvailable(config: { + websocket?: { + enabled?: boolean; + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; +}): boolean { + return !!( + config.websocket?.url && + config.websocket?.worker_url && + typeof SharedWorker !== 'undefined' + ); +} diff --git a/dash/dash-renderer/webpack.base.config.js b/dash/dash-renderer/webpack.base.config.js index ed95239f7d..e8a9d14596 100644 --- a/dash/dash-renderer/webpack.base.config.js +++ b/dash/dash-renderer/webpack.base.config.js @@ -72,6 +72,31 @@ const rendererOptions = { ...defaults }; +// WebSocket Worker configuration +const workerOptions = { + mode: 'production', + entry: { + 'dash-ws-worker': '../../@plotly/dash-websocket-worker/src/worker.ts', + }, + output: { + path: path.resolve(__dirname, "build"), + filename: '[name].js', + }, + target: 'webworker', + module: { + rules: [ + { + test: /\.ts$/, + exclude: /node_modules/, + use: ['ts-loader'], + }, + ] + }, + resolve: { + extensions: ['.ts', '.js'] + } +}; + module.exports = options => [ R.mergeAll([ options, @@ -109,5 +134,7 @@ module.exports = options => [ ] ), } - ]) + ]), + // WebSocket Worker build + workerOptions ]; diff --git a/dash/dash.py b/dash/dash.py index 78d809d921..ddae7896b4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -483,6 +483,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches health_endpoint: Optional[str] = None, csrf_token_name: str = "_csrf_token", csrf_header_name: str = "X-CSRFToken", + websocket_callbacks: Optional[bool] = False, + websocket_allowed_origins: Optional[List[str]] = None, + websocket_inactivity_timeout: Optional[int] = 300000, **obsolete, ): @@ -639,6 +642,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._assets_files: list = [] self._background_manager = background_callback_manager + self._websocket_callbacks = websocket_callbacks + self._websocket_allowed_origins = websocket_allowed_origins or [] + self._websocket_inactivity_timeout = websocket_inactivity_timeout self.logger = logging.getLogger(__name__) @@ -781,6 +787,12 @@ def _setup_routes(self): ) if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) + + # Set up WebSocket callback route if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + self.backend.serve_websocket_callback(self) + self.backend.setup_index(self) self.backend.setup_catchall(self) @@ -962,6 +974,16 @@ def _config(self): custom_dev_tools.append({**hook_dev_tools, "props": props}) config["dev_tools"] = custom_dev_tools + # Add websocket config if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: + config["websocket"] = { + "enabled": bool(self._websocket_callbacks), + "url": self.config.requests_pathname_prefix + "_dash-ws-callback", + "worker_url": self._get_worker_url(), + "inactivity_timeout": self._websocket_inactivity_timeout, + } + return config def serve_reload_hash(self): @@ -989,6 +1011,33 @@ def serve_health(self): """ return self.backend.make_response("OK", status=200, mimetype="text/plain") + def _get_worker_url(self) -> str: + """Get the URL for the WebSocket worker script. + + Returns: + The fingerprinted URL for the worker script served via component suites. + """ + relative_path = "dash-renderer/build/dash-ws-worker.js" + namespace = "dash" + + # Register the path so it can be served + self.registered_paths[namespace].add(relative_path) + + # Build fingerprinted URL (same pattern as _collect_and_register_resources) + module_path = os.path.join( + os.path.dirname(sys.modules[namespace].__file__), # type: ignore + relative_path, + ) + + # Use a fallback if the file doesn't exist yet (during development) + try: + modified = int(os.stat(module_path).st_mtime) + except FileNotFoundError: + modified = 0 + + fingerprint = build_fingerprint(relative_path, __version__, modified) + return f"{self.config.requests_pathname_prefix}_dash-component-suites/{namespace}/{fingerprint}" + def get_dist(self, libraries: Sequence[str]) -> list: dists = [] for dist_type in ("_js_dist", "_css_dist"): diff --git a/dash/exceptions.py b/dash/exceptions.py index 019f0d2726..40e882c409 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -113,3 +113,7 @@ class HookError(DashException): class AppNotFoundError(DashException): pass + + +class WebSocketCallbackError(CallbackException): + pass diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index 6e6cc8b810..51a938f72f 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -147,6 +147,7 @@ class ThreadedRunner(BaseDashRunner): def __init__(self, keep_open=False, stop_timeout=3): super().__init__(keep_open=keep_open, stop_timeout=stop_timeout) self.thread = None + self._app = None # Store app reference for graceful shutdown def running_and_accessible(self, url): if self.thread.is_alive(): # type: ignore[reportOptionalMemberAccess] @@ -156,6 +157,7 @@ def running_and_accessible(self, url): # pylint: disable=arguments-differ def start(self, app, start_timeout=3, **kwargs): """Start the app server in threading flavor.""" + self._app = app # Store app reference for graceful shutdown def run(): app.scripts.config.serve_locally = True @@ -175,7 +177,10 @@ def run(): # FastAPI support if module.startswith("fastapi"): app.run(**options) - # Dash/Flask/Quart fallback + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) else: app.run(threaded=True, **options) except SystemExit: @@ -213,9 +218,17 @@ def run(): raise DashAppLoadingError("threaded server failed to start") def stop(self): - self.thread.kill() # type: ignore[reportOptionalMemberAccess] - self.thread.join() # type: ignore[reportOptionalMemberAccess] - wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + # For FastAPI apps with uvicorn, use graceful shutdown + if self._app and hasattr(self._app, "_uvicorn_server"): + server = self._app._uvicorn_server # pylint: disable=protected-access + server.should_exit = True + self.thread.join(timeout=self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + else: + # Fall back to killing threads for Flask/other backends + self.thread.kill() # type: ignore[reportOptionalMemberAccess] + self.thread.join() # type: ignore[reportOptionalMemberAccess] + wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + self._app = None self.started = False @@ -239,7 +252,10 @@ def target(): # FastAPI support if module.startswith("fastapi"): app.run(**options) - # Dash/Flask/Quart fallback + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) else: app.run(threaded=True, **options) except SystemExit: diff --git a/dash/version.py b/dash/version.py index 25b76de3c3..6af77684e6 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc0" +__version__ = "4.2.0rc2" diff --git a/r19.py b/r19.py new file mode 100644 index 0000000000..815d2a5066 --- /dev/null +++ b/r19.py @@ -0,0 +1,222 @@ +""" +React 19 test app with most Dash components. +Run with: python r19.py +""" + +import os +os.environ["REACT_VERSION"] = "19.2.0" + +from dash import Dash, html, dcc, dash_table, callback, Input, Output +import plotly.express as px +import pandas as pd + +# Sample data +df = pd.DataFrame({ + "Fruit": ["Apples", "Oranges", "Bananas", "Grapes", "Strawberries"], + "Amount": [4, 2, 5, 3, 6], + "City": ["NYC", "LA", "Chicago", "Houston", "Phoenix"] +}) + +app = Dash(__name__) + +app.layout = html.Div([ + html.H1("React 19 Component Test"), + html.P(f"Running React version: {os.environ.get('REACT_VERSION')}"), + + html.Hr(), + html.H2("Core HTML Components"), + html.Div([ + html.Button("Click Me", id="button", n_clicks=0), + html.Span(" Clicks: ", style={"marginLeft": "10px"}), + html.Span(id="click-output", children="0"), + ]), + + html.Hr(), + html.H2("Input Components"), + html.Div([ + html.Label("Text Input:"), + dcc.Input(id="text-input", type="text", placeholder="Type something...", debounce=True), + html.Div(id="text-output"), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Dropdown:"), + dcc.Dropdown( + id="dropdown", + options=[{"label": f, "value": f} for f in df["Fruit"]], + value="Apples", + clearable=True, + ), + html.Div(id="dropdown-output"), + ], style={"marginBottom": "20px", "width": "300px"}), + + html.Div([ + html.Label("Multi-Select Dropdown:"), + dcc.Dropdown( + id="multi-dropdown", + options=[{"label": f, "value": f} for f in df["Fruit"]], + value=["Apples", "Oranges"], + multi=True, + ), + ], style={"marginBottom": "20px", "width": "300px"}), + + html.Div([ + html.Label("Slider:"), + dcc.Slider(id="slider", min=0, max=10, step=1, value=5, marks={i: str(i) for i in range(11)}), + html.Div(id="slider-output"), + ], style={"marginBottom": "20px", "width": "400px"}), + + html.Div([ + html.Label("Range Slider:"), + dcc.RangeSlider(id="range-slider", min=0, max=100, step=10, value=[20, 80]), + ], style={"marginBottom": "20px", "width": "400px"}), + + html.Div([ + html.Label("Radio Items:"), + dcc.RadioItems( + id="radio", + options=[{"label": c, "value": c} for c in df["City"]], + value="NYC", + inline=True, + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Checklist:"), + dcc.Checklist( + id="checklist", + options=[{"label": c, "value": c} for c in df["City"]], + value=["NYC", "LA"], + inline=True, + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Date Picker:"), + dcc.DatePickerSingle(id="date-picker", date="2024-01-15"), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Date Range Picker:"), + dcc.DatePickerRange( + id="date-range", + start_date="2024-01-01", + end_date="2024-12-31", + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Textarea:"), + dcc.Textarea(id="textarea", value="Some text here...", style={"width": "300px", "height": "100px"}), + ], style={"marginBottom": "20px"}), + + html.Hr(), + html.H2("Graph Component"), + dcc.Graph( + id="graph", + figure=px.bar(df, x="Fruit", y="Amount", color="City", title="Fruit Amounts by City") + ), + + html.Hr(), + html.H2("DataTable"), + dash_table.DataTable( + id="table", + columns=[{"name": c, "id": c} for c in df.columns], + data=df.to_dict("records"), + editable=True, + filter_action="native", + sort_action="native", + row_selectable="multi", + page_size=10, + ), + + html.Hr(), + html.H2("Tabs"), + dcc.Tabs(id="tabs", value="tab-1", children=[ + dcc.Tab(label="Tab 1", value="tab-1", children=[ + html.Div("Content for Tab 1", style={"padding": "20px"}) + ]), + dcc.Tab(label="Tab 2", value="tab-2", children=[ + html.Div("Content for Tab 2", style={"padding": "20px"}) + ]), + ]), + + html.Hr(), + html.H2("Loading Component"), + dcc.Loading( + id="loading", + type="circle", + children=html.Div(id="loading-output", children="Content loaded!") + ), + + html.Hr(), + html.H2("Markdown"), + dcc.Markdown(""" + ### This is Markdown + + - Item 1 + - Item 2 + - **Bold text** + - *Italic text* + + ```python + def hello(): + return "Hello, React 19!" + ``` + """), + + html.Hr(), + html.H2("Store & Interval"), + dcc.Store(id="store", data={"count": 0}), + dcc.Interval(id="interval", interval=5000, n_intervals=0, disabled=True), + html.Div(id="interval-output", children="Interval disabled"), + + html.Hr(), + html.H2("Clipboard"), + dcc.Clipboard(id="clipboard", target_id="text-input", style={"fontSize": "20px"}), + + html.Hr(), + html.H2("Tooltip"), + html.Div([ + html.Span("Hover over the graph points to see tooltips", style={"fontStyle": "italic"}), + ]), + + html.Br(), + html.Br(), +], style={"padding": "20px", "maxWidth": "800px", "margin": "0 auto"}) + + +@callback( + Output("click-output", "children"), + Input("button", "n_clicks") +) +def update_clicks(n): + return str(n) + + +@callback( + Output("text-output", "children"), + Input("text-input", "value") +) +def update_text(value): + return f"You typed: {value}" if value else "" + + +@callback( + Output("dropdown-output", "children"), + Input("dropdown", "value") +) +def update_dropdown(value): + return f"Selected: {value}" if value else "Nothing selected" + + +@callback( + Output("slider-output", "children"), + Input("slider", "value") +) +def update_slider(value): + return f"Slider value: {value}" + + +if __name__ == "__main__": + app.run(debug=True, port=8050) diff --git a/requirements/install.txt b/requirements/install.txt index df0e1299e3..284f3a5031 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -7,3 +7,4 @@ requests retrying nest-asyncio setuptools +janus>=1.0.0 diff --git a/tests/websocket/__init__.py b/tests/websocket/__init__.py new file mode 100644 index 0000000000..1116026afc --- /dev/null +++ b/tests/websocket/__init__.py @@ -0,0 +1 @@ +# WebSocket callback tests diff --git a/tests/websocket/conftest.py b/tests/websocket/conftest.py new file mode 100644 index 0000000000..d72fcd04dc --- /dev/null +++ b/tests/websocket/conftest.py @@ -0,0 +1,12 @@ +import pytest +from dash import hooks + + +@pytest.fixture +def ws_hook_cleanup(): + """Clean up WebSocket hooks after each test.""" + yield + hooks._ns["websocket_connect"] = [] + hooks._ns["websocket_message"] = [] + hooks._finals.pop("websocket_connect", None) + hooks._finals.pop("websocket_message", None) diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py new file mode 100644 index 0000000000..935d633339 --- /dev/null +++ b/tests/websocket/test_ws_basic.py @@ -0,0 +1,254 @@ +""" +Basic WebSocket callback tests. + +Tests: +- Per-callback websocket (websocket=True) +- Global websocket callbacks (websocket_callbacks=True) +- Mixed HTTP and WebSocket callbacks +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_ws001_per_callback_websocket(dash_duo): + """Test single callback with websocket=True on FastAPI backend.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_ws002_global_websocket_callbacks(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_ws003_mixed_http_and_websocket(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_ws004_websocket_with_state(dash_duo): + """Test WebSocket callback with State inputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_ws005_websocket_context_available(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.get_websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_ws006_websocket_multiple_outputs(dash_duo): + """Test WebSocket callback with multiple outputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] + + +def test_ws007_websocket_slider_callback(dash_duo): + """Test WebSocket callback with dcc.Slider component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Slider(id="slider", min=0, max=100, value=50, step=10), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("slider", "value")) + def update_output(value): + return f"Slider value: {value}" + + dash_duo.start_server(app) + + # Initial callback should work via WebSocket + dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_hooks.py b/tests/websocket/test_ws_hooks.py new file mode 100644 index 0000000000..db0a166efd --- /dev/null +++ b/tests/websocket/test_ws_hooks.py @@ -0,0 +1,285 @@ +""" +WebSocket hooks tests. + +Tests: +- websocket_connect hook - accept/reject connections +- websocket_message hook - accept/reject messages +- Custom close codes and reasons +""" + +from dash import Dash, html, Input, Output, hooks + + +def test_ws010_connect_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that accepts all connections.""" + connection_count = {"value": 0} + + @hooks.websocket_connect() + def allow_all(websocket): + connection_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Hook should have been called at least once for connection + assert connection_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws011_connect_hook_reject_false(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with False. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ + + @hooks.websocket_connect() + def reject_all(websocket): + return False + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) # Give time for potential callback + assert dash_duo.find_element("#output").text == "initial" + + dash_duo.find_element("#btn").click() + time.sleep(1) + # Still initial since WebSocket was rejected + assert dash_duo.find_element("#output").text == "initial" + + +def test_ws012_connect_hook_reject_tuple(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with custom code/reason. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ + + @hooks.websocket_connect() + def reject_with_reason(websocket): + return (4001, "Connection not allowed") + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" + + dash_duo.find_element("#btn").click() + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" + + +def test_ws013_message_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that accepts all messages.""" + message_count = {"value": 0} + + @hooks.websocket_message() + def allow_all_messages(websocket, message): + message_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Message hook should have been called + assert message_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws014_message_hook_reject(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that rejects specific messages.""" + reject_clicks = {"should_reject": False} + + @hooks.websocket_message() + def conditional_reject(websocket, message): + if reject_clicks["should_reject"]: + return (4010, "Message rejected") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # First click should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws015_async_connect_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_connect hook.""" + import asyncio + + @hooks.websocket_connect() + async def async_validate(websocket): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws016_async_message_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_message hook.""" + import asyncio + + @hooks.websocket_message() + async def async_validate_message(websocket, message): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws017_multiple_connect_hooks(dash_duo, ws_hook_cleanup): + """Test multiple websocket_connect hooks with priorities.""" + hook_order = [] + + @hooks.websocket_connect(priority=1) + def first_hook(websocket): + hook_order.append("first") + return True + + @hooks.websocket_connect(priority=2) + def second_hook(websocket): + hook_order.append("second") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Both hooks should have been called + assert "first" in hook_order + assert "second" in hook_order + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_inactivity.py b/tests/websocket/test_ws_inactivity.py new file mode 100644 index 0000000000..8dd95e1094 --- /dev/null +++ b/tests/websocket/test_ws_inactivity.py @@ -0,0 +1,194 @@ +""" +WebSocket inactivity timeout tests. + +Tests: +- Connection closes after inactivity period +- Activity resets the timer +- Heartbeats don't count as activity +- Auto-reconnect when callback fires after timeout +""" + +import time +from dash import Dash, html, Input, Output + + +def test_ws020_inactivity_timeout_closes(dash_duo): + """Test that WebSocket connection closes after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds for testing + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Trigger callback to establish connection + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for inactivity timeout + time.sleep(4) + + # Click again - should auto-reconnect and work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + +def test_ws021_activity_resets_timer(dash_duo): + """Test that callback activity resets the inactivity timer.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=4000, # 4 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + + # Click every 2 seconds - should keep connection alive + for i in range(1, 4): + time.sleep(2) + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", f"Clicked {i}") + + # All clicks should work without disconnection + assert dash_duo.get_logs() == [] + + +def test_ws022_quick_successive_callbacks(dash_duo): + """Test rapid successive callbacks work correctly.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=5000, + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Rapid clicks + for _ in range(5): + dash_duo.find_element("#btn").click() + time.sleep(0.1) + + dash_duo.wait_for_text_to_equal("#output", "5") + assert dash_duo.get_logs() == [] + + +def test_ws023_auto_reconnect_after_timeout(dash_duo): + """Test auto-reconnect when callback fires after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Initial callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for timeout to expire + time.sleep(3) + + # Click again - should auto-reconnect + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + # And keep working + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 3") + + assert dash_duo.get_logs() == [] + + +def test_ws024_long_callback_doesnt_timeout(dash_duo): + """Test that long-running callbacks don't cause timeout during execution.""" + import asyncio + + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start Long Task", id="btn"), + html.Div("ready", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + async def long_task(n_clicks): + if not n_clicks: + return "ready" + # Simulate long task (longer than inactivity timeout) + await asyncio.sleep(2) + return f"Completed task {n_clicks}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "ready") + + # Start long task + dash_duo.find_element("#btn").click() + + # Should complete despite being longer than half the timeout + dash_duo.wait_for_text_to_equal("#output", "Completed task 1", timeout=10) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_origin.py b/tests/websocket/test_ws_origin.py new file mode 100644 index 0000000000..c6235613a5 --- /dev/null +++ b/tests/websocket/test_ws_origin.py @@ -0,0 +1,154 @@ +""" +WebSocket origin validation tests. + +Tests: +- Same-origin connections allowed by default +- Cross-origin rejected unless explicitly allowed +- websocket_allowed_origins configuration +""" + +from dash import Dash, html, Input, Output + + +def test_ws040_same_origin_allowed(dash_duo): + """Test that same-origin WebSocket connections work by default.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin request should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws041_websocket_allowed_origins_empty(dash_duo): + """Test with empty websocket_allowed_origins (only same-origin).""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=[], # Only same-origin + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin should still work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws042_websocket_allowed_origins_wildcard(dash_duo): + """Test with wildcard in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["*"], # Allow all origins + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws043_websocket_allowed_origins_specific(dash_duo): + """Test with specific origins in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Should work since we're running on localhost + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws044_origin_with_per_callback_websocket(dash_duo): + """Test origin validation with per-callback websocket=True.""" + app = Dash( + __name__, + backend="fastapi", + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback( + Output("output", "children"), Input("btn", "n_clicks"), websocket=True + ) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_patch.py b/tests/websocket/test_ws_patch.py new file mode 100644 index 0000000000..f278b5f26e --- /dev/null +++ b/tests/websocket/test_ws_patch.py @@ -0,0 +1,42 @@ +""" +WebSocket set_props with Patch object test. + +Verifies that set_props works with Patch objects in websocket callbacks. +""" + +from dash import Dash, html, Input, Output, set_props, Patch +from dash.exceptions import PreventUpdate + + +def test_ws037_set_props_with_patch(dash_duo): + """Test set_props with Patch object in websocket callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Patch", id="btn"), + html.Div("initial", id="output"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + def patch_append(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" + click {n}" + + set_props("output", {"children": p}) + return f"Appended {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output", "initial + click 1", timeout=10) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py new file mode 100644 index 0000000000..e800668ae8 --- /dev/null +++ b/tests/websocket/test_ws_props.py @@ -0,0 +1,471 @@ +""" +WebSocket set_props and get_props tests. + +Tests: +- set_props streaming during long-running callback +- get_prop reads current component value +- async set_prop method +- set_props with Patch objects (bug fix for component property updates) +- set_props with pattern-matching components triggering MATCH callbacks +""" + +import asyncio +from dash import Dash, html, Input, Output, State, set_props, MATCH +from dash.exceptions import PreventUpdate + + +def test_ws030_set_props_streaming(dash_duo): + """Test that set_props streams updates during callback execution.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Start", id="btn"), + html.Div("0%", id="progress"), + html.Div("waiting", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def long_task(n): + if not n: + raise PreventUpdate + + for i in range(1, 6): + set_props("progress", {"children": f"{i * 20}%"}) + await asyncio.sleep(0.1) + + return "Done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#progress", "0%") + dash_duo.wait_for_text_to_equal("#result", "waiting") + + dash_duo.find_element("#btn").click() + + # Should see progress updates and final result + dash_duo.wait_for_text_to_equal("#result", "Done", timeout=10) + # Final progress should be 100% + dash_duo.wait_for_text_to_equal("#progress", "100%") + + assert dash_duo.get_logs() == [] + + +def test_ws031_set_props_multiple_components(dash_duo): + """Test set_props updating multiple components during callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update All", id="btn"), + html.Div("A: initial", id="output-a"), + html.Div("B: initial", id="output-b"), + html.Div("C: initial", id="output-c"), + html.Div("result", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_all(n): + if not n: + raise PreventUpdate + + set_props("output-a", {"children": f"A: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-b", {"children": f"B: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-c", {"children": f"C: updated {n}"}) + + return f"All updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output-a", "A: updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#output-b", "B: updated 1") + dash_duo.wait_for_text_to_equal("#output-c", "C: updated 1") + dash_duo.wait_for_text_to_equal("#result", "All updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws032_set_props_with_complex_values(dash_duo): + """Test set_props with various value types.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Test Values", id="btn"), + html.Div(id="text-output"), + html.Div(id="number-output"), + html.Div(id="list-output"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def test_values(n): + if not n: + raise PreventUpdate + + # String + set_props("text-output", {"children": "Hello World"}) + await asyncio.sleep(0.02) + + # Number as string + set_props("number-output", {"children": str(42)}) + await asyncio.sleep(0.02) + + # List of strings + set_props("list-output", {"children": ["Item 1", " - ", "Item 2"]}) + + return "Values set" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#text-output", "Hello World", timeout=10) + dash_duo.wait_for_text_to_equal("#number-output", "42") + dash_duo.wait_for_text_to_equal("#list-output", "Item 1 - Item 2") + dash_duo.wait_for_text_to_equal("#result", "Values set") + + assert dash_duo.get_logs() == [] + + +def test_ws033_set_props_sync_callback(dash_duo): + """Test set_props in synchronous callback with WebSocket.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Sync Update", id="btn"), + html.Div("before", id="side-effect"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + def sync_update(n): + if not n: + raise PreventUpdate + + # set_props should work in sync callback too + set_props("side-effect", {"children": f"Side effect {n}"}) + return f"Result {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Result 1", timeout=10) + dash_duo.wait_for_text_to_equal("#side-effect", "Side effect 1") + + assert dash_duo.get_logs() == [] + + +def test_ws034_get_prop_reads_value(dash_duo): + """Test that get_prop can read current component values.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Div("Source Value", id="source"), + html.Button("Read", id="btn"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def read_prop(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.get_websocket + if ws: + value = await ws.get_prop("source", "children") + return f"Read: {value}" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Read: Source Value", timeout=10) + + assert dash_duo.get_logs() == [] + + +def test_ws035_websocket_set_prop_method(dash_duo): + """Test using ws.set_prop() method directly.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Set via WS", id="btn"), + html.Div("original", id="target"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def set_via_ws(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.get_websocket + if ws: + await ws.set_prop("target", "children", f"Set via WebSocket {n}") + return "Set complete" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#target", "Set via WebSocket 1", timeout=10) + dash_duo.wait_for_text_to_equal("#result", "Set complete") + + assert dash_duo.get_logs() == [] + + +def test_ws036_set_props_dict_component_id(dash_duo): + """Test set_props with dict component ID (pattern matching).""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div("initial", id={"type": "output", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_with_dict_id(n): + if not n: + raise PreventUpdate + + set_props({"type": "output", "index": 0}, {"children": f"Updated {n}"}) + return f"Done {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + # Use attribute selector for the dict ID + dash_duo.wait_for_text_to_equal( + '[id=\'{"index":0,"type":"output"}\']', "Updated 1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#result", "Done 1") + + assert dash_duo.get_logs() == [] + + +def test_ws045_set_props_component_prop_children(dash_duo): + """Test set_props updating component props like Div's children with component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update Children", id="btn"), + html.Div(id="container"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_children(n): + if not n: + raise PreventUpdate + + set_props( + "container", + { + "children": html.Div( + [ + html.Span(f"Updated {n}"), + html.B(" - Bold Text"), + ] + ) + }, + ) + return f"Children updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#container span", "Updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#container b", "- Bold Text") + dash_duo.wait_for_text_to_equal("#result", "Children updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws046_set_props_nested_component_children(dash_duo): + """Test set_props with nested component in children prop.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update Nested", id="btn"), + html.Div(id="wrapper"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_nested(n): + if not n: + raise PreventUpdate + + set_props( + "wrapper", + { + "children": html.Div( + [ + html.Ul( + [ + html.Li(f"Item {n}.1"), + html.Li(f"Item {n}.2"), + ] + ) + ] + ) + }, + ) + return f"Nested updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal( + "#wrapper ul li:first-child", "Item 1.1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#wrapper ul li:last-child", "Item 1.2") + dash_duo.wait_for_text_to_equal("#result", "Nested updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws047_set_props_children_with_list(dash_duo): + """Test set_props with list of components wrapped in a single component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update List", id="btn"), + html.Div(id="list-container"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_list(n): + if not n: + raise PreventUpdate + + set_props( + "list-container", + { + "children": html.Div( + [ + html.Div(f"Item 1 - {n}"), + html.Div(f"Item 2 - {n}"), + html.Div(f"Item 3 - {n}"), + ] + ) + }, + ) + return f"List updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "List updated 1", timeout=10) + assert "Item 1 - 1" in dash_duo.find_element("#list-container").text + assert "Item 2 - 1" in dash_duo.find_element("#list-container").text + assert "Item 3 - 1" in dash_duo.find_element("#list-container").text + + assert dash_duo.get_logs() == [] + + +def test_ws048_set_props_dynamic_match_callback(dash_duo): + """Test set_props injecting components with pattern-matching IDs that trigger MATCH callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Add Component", id="add-btn"), + html.Div(id="container"), + html.Div("waiting", id="match-result"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("add-btn", "n_clicks")) + async def add_component(n): + if not n: + raise PreventUpdate + + # Inject component with pattern-matching ID via set_props + set_props( + "container", + { + "children": html.Div( + [ + html.Span("Hello"), + html.Button("Click me", id={"type": "dynamic", "index": 0}), + ] + ) + }, + ) + return f"Component added {n}" + + @app.callback( + Output("match-result", "children"), + Input({"type": "dynamic", "index": MATCH}, "n_clicks"), + State({"type": "dynamic", "index": MATCH}, "id"), + prevent_initial_call=True, + ) + def handle_dynamic_click(n_clicks, btn_id): + if not n_clicks: + raise PreventUpdate + return f"Clicked button index {btn_id['index']} - {n_clicks} times" + + dash_duo.start_server(app) + + # Initial state + dash_duo.wait_for_text_to_equal("#match-result", "waiting") + + # Add the dynamic component + dash_duo.find_element("#add-btn").click() + dash_duo.wait_for_text_to_equal("#result", "Component added 1", timeout=10) + + # Verify the component was added + dash_duo.wait_for_text_to_equal("#container span", "Hello", timeout=5) + + # Click the dynamically added button with pattern-matching ID + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + + # Verify the MATCH callback fired + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 1 times", timeout=10 + ) + + # Click again to verify it continues to work + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 2 times", timeout=10 + ) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py new file mode 100644 index 0000000000..30e33b329c --- /dev/null +++ b/tests/websocket/test_ws_quart.py @@ -0,0 +1,228 @@ +""" +Quart WebSocket callback tests. + +Tests the Quart backend websocket implementation which mirrors the FastAPI backend. +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_wsq001_per_callback_websocket_quart(dash_duo): + """Test single callback with websocket=True on Quart backend.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test (Quart)"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_wsq002_global_websocket_callbacks_quart(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks on Quart.""" + app = Dash( + __name__, + backend="quart", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_wsq003_mixed_http_and_websocket_quart(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app on Quart.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_wsq004_websocket_with_state_quart(dash_duo): + """Test WebSocket callback with State inputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_wsq005_websocket_context_available_quart(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.get_websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_wsq006_websocket_multiple_outputs_quart(dash_duo): + """Test WebSocket callback with multiple outputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_validate.py b/tests/websocket/test_ws_validate.py new file mode 100644 index 0000000000..0a43ada553 --- /dev/null +++ b/tests/websocket/test_ws_validate.py @@ -0,0 +1,58 @@ +import pytest + +from dash.exceptions import WebSocketCallbackError +from dash._validate import validate_websocket_callback_request + + +class TestWebsocketCallbackRequestValidation: + """Tests for runtime WebSocket callback request validation.""" + + def test_global_enabled_allows_any_callback(self): + """When websocket_callbacks=True globally, any callback can use WebSocket.""" + callback_map = { + "out1.children": {"websocket": False}, + "out2.children": {}, # no websocket key + } + # Should not raise - global setting allows all + validate_websocket_callback_request("out1.children", callback_map, True) + validate_websocket_callback_request("out2.children", callback_map, True) + + def test_per_callback_websocket_enabled_passes(self): + """Callback with websocket=True should pass when global is False.""" + callback_map = { + "out1.children": {"websocket": True}, + } + # Should not raise + validate_websocket_callback_request("out1.children", callback_map, False) + + def test_per_callback_websocket_disabled_raises(self): + """Callback without websocket=True should raise when global is False.""" + callback_map = { + "out1.children": {"websocket": False}, + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + assert "websocket=True" in str(exc_info.value) + + def test_callback_without_websocket_key_raises(self): + """Callback without websocket key should raise when global is False.""" + callback_map = { + "out1.children": {}, # no websocket key + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + + def test_unknown_callback_raises(self): + """Unknown callback ID should raise when global is False.""" + callback_map = {} + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("unknown.children", callback_map, False) + + assert "unknown.children" in str(exc_info.value) diff --git a/wsapp.py b/wsapp.py new file mode 100644 index 0000000000..98b2db2f38 --- /dev/null +++ b/wsapp.py @@ -0,0 +1,106 @@ +""" +Test app for WebSocket-based callbacks. + +Run with: + python wsapp.py + +Then open http://127.0.0.1:8050 in your browser. +""" + +from dash import Dash, html, dcc, callback, Output, Input, ctx +import time + +# Create app with FastAPI backend and WebSocket callbacks enabled +app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, +) + +app.layout = html.Div([ + html.H1("WebSocket Callbacks Test"), + + html.Div([ + html.H3("Basic Callback Test"), + html.Button("Click me", id="btn-1", n_clicks=0), + html.Div(id="output-1"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Input Test"), + dcc.Input(id="input-1", type="text", placeholder="Type something..."), + html.Div(id="output-2"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Slider Test"), + dcc.Slider(id="slider-1", min=0, max=100, value=50), + html.Div(id="output-3"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("set_props Test"), + html.Button("Update via set_props", id="btn-2", n_clicks=0), + html.Div(id="output-4", children="Initial content"), + html.Div(id="output-5", children="Will be updated by set_props"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("WebSocket Context Test"), + html.Button("Check WebSocket Context", id="btn-3", n_clicks=0), + html.Div(id="output-6"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div(id="config-display", style={"marginTop": "20px", "fontSize": "12px", "color": "#666"}), +]) + + +@callback(Output("output-1", "children"), Input("btn-1", "n_clicks")) +def update_output_1(n_clicks): + return f"Button clicked {n_clicks} times" + + +@callback(Output("output-2", "children"), Input("input-1", "value")) +def update_output_2(value): + return f"You typed: {value}" + + +@callback(Output("output-3", "children"), Input("slider-1", "value")) +def update_output_3(value): + return f"Slider value: {value}" + + +@callback(Output("output-4", "children"), Input("btn-2", "n_clicks")) +def update_with_set_props(n_clicks): + if n_clicks > 0: + # Use set_props to update another component + from dash._callback_context import set_props + set_props("output-5", {"children": f"Updated via set_props at click {n_clicks}"}) + return f"set_props button clicked {n_clicks} times" + + +@callback(Output("output-6", "children"), Input("btn-3", "n_clicks")) +def check_websocket_context(n_clicks): + if n_clicks > 0: + ws = ctx.get_websocket + if ws is not None: + return f"WebSocket context is available! (click {n_clicks})" + else: + return f"WebSocket context is None (click {n_clicks}) - may be using HTTP fallback" + return "Click to check WebSocket context" + + +@callback(Output("config-display", "children"), Input("btn-1", "n_clicks")) +def show_config(n_clicks): + config = app._config() + ws_config = config.get("websocket", {}) + if ws_config: + return f"WebSocket enabled: {ws_config.get('enabled')}, URL: {ws_config.get('url')}" + return "WebSocket not configured" + + +if __name__ == "__main__": + print("Starting WebSocket callbacks test app...") + print(f"WebSocket callbacks enabled: {app._websocket_callbacks}") + print(f"Backend websocket capability: {app.backend.websocket_capability}") + app.run(debug=True, port=8050) diff --git a/wscb.py b/wscb.py new file mode 100644 index 0000000000..629ed3cdc1 --- /dev/null +++ b/wscb.py @@ -0,0 +1,68 @@ +""" +Test app for per-callback WebSocket support. + +This app demonstrates using websocket=True on specific callbacks +without enabling global websocket_callbacks. +""" + +from dash import Dash, html, dcc, callback, Input, Output, State + +app = Dash(__name__, backend="fastapi") + +app.layout = html.Div([ + html.H1("Per-Callback WebSocket Test"), + + html.Div([ + html.H3("WebSocket Callback"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output", style={"padding": "10px", "background": "#e0ffe0"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("HTTP Callback (default)"), + dcc.Input(id="http-input", type="text", placeholder="Type here..."), + html.Div(id="http-output", style={"padding": "10px", "background": "#e0e0ff"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("WebSocket Counter"), + html.Button("Increment", id="ws-btn"), + html.Div(id="ws-counter", children="0", style={"padding": "10px", "background": "#ffe0e0"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), +]) + + +@callback( + Output("ws-output", "children"), + Input("ws-input", "value"), + websocket=True +) +def ws_callback(value): + """This callback uses WebSocket.""" + return f"[WebSocket] You typed: {value or ''}" + + +@callback( + Output("http-output", "children"), + Input("http-input", "value") +) +def http_callback(value): + """This callback uses HTTP (default).""" + return f"[HTTP] You typed: {value or ''}" + + +@callback( + Output("ws-counter", "children"), + Input("ws-btn", "n_clicks"), + State("ws-counter", "children"), + websocket=True +) +def ws_counter(n_clicks, current): + """WebSocket counter callback.""" + if n_clicks is None: + return "0" + return str(int(current or 0) + 1) + + +if __name__ == "__main__": + app.run(debug=True)