From 965e07367b0a14ee8bc40c6fc408334eafadc77f Mon Sep 17 00:00:00 2001 From: Ian Murray Date: Fri, 12 Jun 2026 08:44:23 +0000 Subject: [PATCH 1/4] Fix missing request context (cookies/headers) in WebSocket callbacks Propagate cookies, headers, args, path, remote, and origin from the WebSocket handshake onto the callback context in create_ws_context, mirroring the HTTP path. This fixes auth helpers (e.g. dash_enterprise_auth.get_user_data) failing over the WebSocket transport because callback_context.cookies/headers were empty. --- dash/backends/_fastapi.py | 14 ++++++++ dash/backends/_quart.py | 14 ++++++++ dash/backends/ws.py | 23 +++++++++++- tests/websocket/test_ws_context.py | 56 ++++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 tests/websocket/test_ws_context.py diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 0516b5edcb..5167922dab 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -725,6 +725,19 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() + # Capture request metadata from the WebSocket handshake once per + # connection so that callbacks running over the WebSocket transport + # can access cookies/headers (e.g. for authentication helpers such + # as dash_enterprise_auth.get_user_data). + request_context = { + "cookies": dict(websocket.cookies), + "headers": dict(websocket.headers), + "args": dict(websocket.query_params), + "path": websocket.url.path, + "remote": websocket.client.host if websocket.client else "", + "origin": websocket.headers.get("origin", ""), + } + # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests with standard queue.Queue for responses @@ -788,6 +801,7 @@ async def websocket_handler(websocket: WebSocket): payload, ws_cb, FastAPIResponseAdapter(), + request_context, ) # Set up done callback to send response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c7634ce93a..f827d5115a 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -545,6 +545,19 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() + # Capture request metadata from the WebSocket handshake once per + # connection so that callbacks running over the WebSocket transport + # can access cookies/headers (e.g. for authentication helpers such + # as dash_enterprise_auth.get_user_data). + request_context = { + "cookies": dict(ws.cookies), + "headers": dict(ws.headers), + "args": dict(ws.args), + "path": ws.path, + "remote": ws.remote_addr, + "origin": ws.headers.get("origin", ""), + } + # Track this connection for graceful shutdown try: ws_obj = ws._get_current_object() @@ -623,6 +636,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches payload, ws_cb, QuartResponseAdapter(), + request_context, ) # Set up done callback to send response diff --git a/dash/backends/ws.py b/dash/backends/ws.py index a4b302f215..c0fe3da4d7 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -189,6 +189,7 @@ def create_ws_context( payload: dict, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, + request_context: "dict | None" = None, ): """Create callback context from WebSocket message. @@ -196,6 +197,12 @@ def create_ws_context( payload: The callback payload response_adapter: The response adapter instance for the backend websocket_callback: The websocket callback instance for the backend + request_context: Optional request metadata (cookies, headers, args, + path, remote, origin) captured from the WebSocket handshake. This + mirrors the context populated for regular HTTP callbacks so that + ``callback_context.cookies``/``headers`` (and downstream helpers + such as ``dash_enterprise_auth.get_user_data``) work inside + WebSocket callbacks. Returns: AttributeDict with callback context @@ -217,6 +224,14 @@ def create_ws_context( g.updated_props = {} g.dash_websocket = websocket_callback + request_context = request_context or {} + g.cookies = request_context.get("cookies", {}) + g.headers = request_context.get("headers", {}) + g.args = request_context.get("args", "") + g.path = request_context.get("path", "") + g.remote = request_context.get("remote", "") + g.origin = request_context.get("origin", "") + return g @@ -396,6 +411,7 @@ def run_callback_in_executor( payload: dict, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", + request_context: "dict | None" = None, ) -> concurrent.futures.Future: """Submit callback to executor for thread pool execution. @@ -408,6 +424,9 @@ def run_callback_in_executor( payload: The callback payload from WebSocket message ws_callback: WebSocket callback instance for set_prop/get_prop response_adapter: Response adapter for the backend + request_context: Optional request metadata (cookies, headers, args, + path, remote, origin) captured from the WebSocket handshake, made + available on the callback context. Returns: Future representing the pending callback execution @@ -415,7 +434,9 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + cb_ctx = create_ws_context( + payload, response_adapter, ws_callback, request_context + ) # pylint: disable=protected-access func = dash_app._prepare_callback(cb_ctx, payload) args = dash_app._inputs_to_vals( # pylint: disable=protected-access diff --git a/tests/websocket/test_ws_context.py b/tests/websocket/test_ws_context.py new file mode 100644 index 0000000000..1b77d30fe8 --- /dev/null +++ b/tests/websocket/test_ws_context.py @@ -0,0 +1,56 @@ +"""Unit tests for WebSocket callback context creation. + +These tests verify that request metadata captured from the WebSocket +handshake (cookies, headers, etc.) is propagated onto the callback +context. This is required so authentication helpers that read +``callback_context.cookies``/``headers`` (such as +``dash_enterprise_auth.get_user_data``) work inside WebSocket callbacks. +""" + +from dash.backends.ws import create_ws_context + + +def test_create_ws_context_propagates_request_context(): + """Request metadata should be copied onto the callback context.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + request_context = { + "cookies": {"kcIdToken": "token-value"}, + "headers": {"Plotly-User-Data": "{}"}, + "args": {"foo": "bar"}, + "path": "/_dash-ws-callback", + "remote": "10.0.0.1", + "origin": "https://example.com", + } + + g = create_ws_context(payload, response_adapter=None, websocket_callback=None, request_context=request_context) + + assert g.cookies == {"kcIdToken": "token-value"} + assert g.headers == {"Plotly-User-Data": "{}"} + assert g.args == {"foo": "bar"} + assert g.path == "/_dash-ws-callback" + assert g.remote == "10.0.0.1" + assert g.origin == "https://example.com" + + +def test_create_ws_context_defaults_without_request_context(): + """Context should expose empty defaults when no request context is given.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + + g = create_ws_context(payload, response_adapter=None, websocket_callback=None) + + assert g.cookies == {} + assert g.headers == {} + assert g.args == "" + assert g.path == "" + assert g.remote == "" + assert g.origin == "" From 3cd9addbfb337e8e8c4c613646d2f9aad33c767a Mon Sep 17 00:00:00 2001 From: Ian MacDougall Murray Date: Mon, 15 Jun 2026 09:34:26 +0000 Subject: [PATCH 2/4] Resolve lint warnings in modified files --- dash/backends/_fastapi.py | 3 ++- dash/backends/ws.py | 1 + tests/websocket/test_ws_context.py | 7 ++++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 5167922dab..c7188420b3 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -328,7 +328,8 @@ def setup_catchall(self, dash_app: Dash): and passed through the middleware, which is necessary for features like authentication and timing to work correctly on all routes. FastAPI will match this catch-all route for any path that isn't matched by a more specific route, allowing the middleware to - process the request and then return the appropriate response (e.g., 404 if no Dash route matches).""" + process the request and then return the appropriate response (e.g., 404 if no Dash route matches). + """ def _setup_catchall(self): try: diff --git a/dash/backends/ws.py b/dash/backends/ws.py index c0fe3da4d7..6df9a6e15e 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -3,6 +3,7 @@ This module provides the WebSocket callback infrastructure for real-time bidirectional communication between Dash backends and the renderer. """ + from __future__ import annotations import asyncio diff --git a/tests/websocket/test_ws_context.py b/tests/websocket/test_ws_context.py index 1b77d30fe8..45bc758191 100644 --- a/tests/websocket/test_ws_context.py +++ b/tests/websocket/test_ws_context.py @@ -27,7 +27,12 @@ def test_create_ws_context_propagates_request_context(): "origin": "https://example.com", } - g = create_ws_context(payload, response_adapter=None, websocket_callback=None, request_context=request_context) + g = create_ws_context( + payload, + response_adapter=None, + websocket_callback=None, + request_context=request_context, + ) assert g.cookies == {"kcIdToken": "token-value"} assert g.headers == {"Plotly-User-Data": "{}"} From 84af61ca6ea6bcabb279bc3bbda5e75b11ec9eec Mon Sep 17 00:00:00 2001 From: Ian MacDougall Murray Date: Mon, 15 Jun 2026 09:36:24 +0000 Subject: [PATCH 3/4] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e0404e123..7c990cd053 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). ### Fixed - [#3805](https://github.com/plotly/dash/pull/3805) Fix FastAPI POST routes deadlock caused by middleware consuming request body. Fixes [#3801](https://github.com/plotly/dash/issues/3801). +- [#3815](https://github.com/plotly/dash/pull/3815) Fix missing request context (cookies/headers) in websocket callbacks. ## [4.2.0] - 2026-06-01 - *The Freedom Update* From 3795bc4ce53e9eb65d697bf40325ccf86f71e1df Mon Sep 17 00:00:00 2001 From: Ian MacDougall Murray Date: Thu, 18 Jun 2026 18:26:25 +0000 Subject: [PATCH 4/4] Reuse set_current_request machinery for WebSocket callback context Populate callback_context request metadata (cookies/headers/args/path/ remote/origin) for WebSocket callbacks the same way as HTTP callbacks, via a shared populate_request_metadata helper and an activate_request seam (FastAPI uses set_current_request; Quart uses a handshake snapshot). Adds unit and end-to-end tests for FastAPI and Quart. --- dash/_utils.py | 17 ++++++ dash/backends/_fastapi.py | 29 ++++----- dash/backends/_quart.py | 34 ++++++----- dash/backends/ws.py | 91 ++++++++++++++++------------- dash/dash.py | 8 +-- tests/websocket/test_ws_basic.py | 28 +++++++++ tests/websocket/test_ws_context.py | 94 +++++++++++++++++++++++++----- tests/websocket/test_ws_quart.py | 34 +++++++++++ 8 files changed, 250 insertions(+), 85 deletions(-) diff --git a/dash/_utils.py b/dash/_utils.py index 1ab2036820..c6a827f10d 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -218,6 +218,23 @@ def inputs_to_dict(inputs_list): return inputs +def populate_request_metadata(g, adapter): + """Copy request metadata from a request adapter onto a context object. + + Shared by the HTTP path (``Dash._initialize_context``) and the WebSocket + path (``dash.backends.ws.create_ws_context``) so that both transports expose + identical request context (cookies, headers, args, path, remote, origin) on + ``callback_context``. + """ + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin + return g + + def convert_to_AttributeDict(nested_list): new_dict = [] for i in nested_list: diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index c7188420b3..6c573e0101 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import contextmanager from contextvars import copy_context, ContextVar import asyncio import concurrent.futures @@ -726,18 +727,20 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() - # Capture request metadata from the WebSocket handshake once per - # connection so that callbacks running over the WebSocket transport - # can access cookies/headers (e.g. for authentication helpers such - # as dash_enterprise_auth.get_user_data). - request_context = { - "cookies": dict(websocket.cookies), - "headers": dict(websocket.headers), - "args": dict(websocket.query_params), - "path": websocket.url.path, - "remote": websocket.client.host if websocket.client else "", - "origin": websocket.headers.get("origin", ""), - } + # Activate the WebSocket handshake request using the same + # request-context machinery as HTTP callbacks (set_current_request) + # so that callbacks running over the WebSocket transport can access + # cookies/headers (e.g. for authentication helpers such as + # dash_enterprise_auth.get_user_data). ContextVars do not propagate + # into the executor threads, so activation happens inside the worker + # thread via this callable. + @contextmanager + def activate_request(): + token = set_current_request(websocket) + try: + yield FastAPIRequestAdapter() + finally: + reset_current_request(token) # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() @@ -802,7 +805,7 @@ async def websocket_handler(websocket: WebSocket): payload, ws_cb, FastAPIResponseAdapter(), - request_context, + activate_request, ) # Set up done callback to send response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index f827d5115a..16e7a77f6c 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -12,8 +12,10 @@ import threading from urllib.parse import urlparse +from contextlib import contextmanager from logging.config import dictConfig from contextvars import copy_context +from types import SimpleNamespace from typing import Any, Dict, TYPE_CHECKING from importlib_metadata import version as _get_distribution_version @@ -545,18 +547,24 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() - # Capture request metadata from the WebSocket handshake once per - # connection so that callbacks running over the WebSocket transport - # can access cookies/headers (e.g. for authentication helpers such - # as dash_enterprise_auth.get_user_data). - request_context = { - "cookies": dict(ws.cookies), - "headers": dict(ws.headers), - "args": dict(ws.args), - "path": ws.path, - "remote": ws.remote_addr, - "origin": ws.headers.get("origin", ""), - } + # Quart's request/websocket context cannot cross into the executor + # threads where callbacks run, so snapshot the handshake metadata + # here (where the ``websocket`` proxy is valid) into an adapter-shaped + # object. It is funnelled through the same ``populate_request_metadata`` + # helper as HTTP callbacks so the context (cookies, headers, args, + # path, remote, origin) is populated identically. + request_snapshot = SimpleNamespace( + cookies=ws.cookies, + headers=ws.headers, + args=ws.args, + full_path=ws.full_path, + remote_addr=ws.remote_addr, + origin=ws.headers.get("origin"), + ) + + @contextmanager + def activate_request(): + yield request_snapshot # Track this connection for graceful shutdown try: @@ -636,7 +644,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches payload, ws_cb, QuartResponseAdapter(), - request_context, + activate_request, ) # Set up done callback to send response diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 6df9a6e15e..55d5c5ccbb 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -9,6 +9,7 @@ import asyncio import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext import inspect import json import queue @@ -190,7 +191,7 @@ def create_ws_context( payload: dict, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, - request_context: "dict | None" = None, + request_adapter: Any = None, ): """Create callback context from WebSocket message. @@ -198,18 +199,20 @@ def create_ws_context( payload: The callback payload response_adapter: The response adapter instance for the backend websocket_callback: The websocket callback instance for the backend - request_context: Optional request metadata (cookies, headers, args, - path, remote, origin) captured from the WebSocket handshake. This - mirrors the context populated for regular HTTP callbacks so that - ``callback_context.cookies``/``headers`` (and downstream helpers - such as ``dash_enterprise_auth.get_user_data``) work inside - WebSocket callbacks. + request_adapter: Optional request adapter (or any object exposing + ``cookies``/``headers``/``args``/``full_path``/``remote_addr``/ + ``origin``) captured from the WebSocket handshake. When provided, + the request metadata is copied onto the context the same way as for + regular HTTP callbacks so that ``callback_context.cookies``/ + ``headers`` (and downstream helpers such as + ``dash_enterprise_auth.get_user_data``) work inside WebSocket + callbacks. Returns: AttributeDict with callback context """ # pylint: disable=import-outside-toplevel - from dash._utils import AttributeDict, inputs_to_dict + from dash._utils import AttributeDict, inputs_to_dict, populate_request_metadata g = AttributeDict({}) g.inputs_list = payload.get("inputs", []) @@ -225,13 +228,15 @@ def create_ws_context( g.updated_props = {} g.dash_websocket = websocket_callback - request_context = request_context or {} - g.cookies = request_context.get("cookies", {}) - g.headers = request_context.get("headers", {}) - g.args = request_context.get("args", "") - g.path = request_context.get("path", "") - g.remote = request_context.get("remote", "") - g.origin = request_context.get("origin", "") + if request_adapter is not None: + populate_request_metadata(g, request_adapter) + else: + g.cookies = {} + g.headers = {} + g.args = {} + g.path = "" + g.remote = "" + g.origin = "" return g @@ -412,7 +417,7 @@ def run_callback_in_executor( payload: dict, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", - request_context: "dict | None" = None, + activate_request: "Callable[[], Any] | None" = None, ) -> concurrent.futures.Future: """Submit callback to executor for thread pool execution. @@ -425,9 +430,13 @@ def run_callback_in_executor( payload: The callback payload from WebSocket message ws_callback: WebSocket callback instance for set_prop/get_prop response_adapter: Response adapter for the backend - request_context: Optional request metadata (cookies, headers, args, - path, remote, origin) captured from the WebSocket handshake, made - available on the callback context. + activate_request: Optional zero-argument callable returning a context + manager that activates the WebSocket handshake request *inside the + worker thread* and yields a request adapter (or ``None``). This + lets each backend reuse its own request-context machinery (e.g. + ``set_current_request`` for FastAPI) so the callback context is + populated the same way as for HTTP callbacks. ContextVars do not + propagate into executor threads, so activation must happen here. Returns: Future representing the pending callback execution @@ -435,31 +444,33 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context( - payload, response_adapter, ws_callback, request_context - ) - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals( # pylint: disable=protected-access - cb_ctx.inputs_list + cb_ctx.states_list - ) + request_cm = activate_request() if activate_request else nullcontext() + with request_cm as request_adapter: + cb_ctx = create_ws_context( + payload, response_adapter, ws_callback, request_adapter + ) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) - ctx = copy_context() - partial_func = ( - dash_app._execute_callback( # pylint: disable=protected-access - func, args, cb_ctx.outputs_list, cb_ctx + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) ) - ) - # Run in new event loop (handles both sync and async callbacks) - def run_callback(): - result = partial_func() - if inspect.iscoroutine(result): - return asyncio.run(result) - return result + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result - response_data = ctx.run(run_callback) - return {"status": "ok", "data": json.loads(response_data)} + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} except PreventUpdate: return {"status": "prevent_update"} diff --git a/dash/dash.py b/dash/dash.py index f547b95b56..4f1c3a892f 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -46,6 +46,7 @@ inputs_to_vals, interpolate_str, patch_collections_abc, + populate_request_metadata, split_callback_id, to_json, convert_to_AttributeDict, @@ -1494,12 +1495,7 @@ def _initialize_context(self, body): for x in body.get("changedPropIds", []) ] g.dash_response = self.backend.response_adapter() - g.cookies = dict(adapter.cookies) - g.headers = dict(adapter.headers) - g.args = adapter.args - g.path = adapter.full_path - g.remote = adapter.remote_addr - g.origin = adapter.origin + populate_request_metadata(g, adapter) g.updated_props = {} return g diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py index 1d74706a68..136d96db46 100644 --- a/tests/websocket/test_ws_basic.py +++ b/tests/websocket/test_ws_basic.py @@ -252,3 +252,31 @@ def update_output(value): dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") assert dash_duo.get_logs() == [] + + +def test_ws008_websocket_request_context_cookies(dash_duo): + """WebSocket callbacks should expose request cookies/headers on ctx (FastAPI).""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def show_context(value): + return f"cookie={ctx.cookies.get('wscookie', '')} headers={bool(ctx.headers)}" + + dash_duo.start_server(app) + + # Set a cookie, then reload so the WebSocket handshake carries it. + dash_duo.driver.add_cookie({"name": "wscookie", "value": "wsval"}) + dash_duo.driver.refresh() + + dash_duo.wait_for_text_to_equal("#ws-output", "cookie=wsval headers=True") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_context.py b/tests/websocket/test_ws_context.py index 45bc758191..3ff72ee3ff 100644 --- a/tests/websocket/test_ws_context.py +++ b/tests/websocket/test_ws_context.py @@ -7,31 +7,35 @@ ``dash_enterprise_auth.get_user_data``) work inside WebSocket callbacks. """ +from types import SimpleNamespace + +import pytest + from dash.backends.ws import create_ws_context def test_create_ws_context_propagates_request_context(): - """Request metadata should be copied onto the callback context.""" + """Request metadata from the adapter should be copied onto the context.""" payload = { "inputs": [], "state": [], "outputs": [], "changedPropIds": [], } - request_context = { - "cookies": {"kcIdToken": "token-value"}, - "headers": {"Plotly-User-Data": "{}"}, - "args": {"foo": "bar"}, - "path": "/_dash-ws-callback", - "remote": "10.0.0.1", - "origin": "https://example.com", - } + request_adapter = SimpleNamespace( + cookies={"kcIdToken": "token-value"}, + headers={"Plotly-User-Data": "{}"}, + args={"foo": "bar"}, + full_path="/_dash-ws-callback", + remote_addr="10.0.0.1", + origin="https://example.com", + ) g = create_ws_context( payload, response_adapter=None, websocket_callback=None, - request_context=request_context, + request_adapter=request_adapter, ) assert g.cookies == {"kcIdToken": "token-value"} @@ -42,8 +46,8 @@ def test_create_ws_context_propagates_request_context(): assert g.origin == "https://example.com" -def test_create_ws_context_defaults_without_request_context(): - """Context should expose empty defaults when no request context is given.""" +def test_create_ws_context_defaults_without_request_adapter(): + """Context should expose empty defaults when no request adapter is given.""" payload = { "inputs": [], "state": [], @@ -55,7 +59,71 @@ def test_create_ws_context_defaults_without_request_context(): assert g.cookies == {} assert g.headers == {} - assert g.args == "" + assert g.args == {} assert g.path == "" assert g.remote == "" assert g.origin == "" + + +def test_run_executor_activates_request_across_thread_boundary(): + """Request activation must populate context inside the executor thread. + + WebSocket callbacks run in a ``ThreadPoolExecutor`` and ContextVars do not + propagate into those threads, so the refactor activates the handshake + request (via FastAPI's ``set_current_request``) *inside* the worker thread. + This guards that seam: a request activated in the worker thread is visible + to ``FastAPIRequestAdapter`` and gets copied onto the callback context. + """ + pytest.importorskip("fastapi") + from concurrent.futures import ThreadPoolExecutor + from contextlib import contextmanager + + from dash.backends._fastapi import ( + FastAPIRequestAdapter, + reset_current_request, + set_current_request, + ) + + # Minimal stand-in for a Starlette ``WebSocket`` handshake connection: it + # only needs the attributes the request adapter reads. + handshake = SimpleNamespace( + cookies={"kcIdToken": "token-value"}, + headers={"origin": "https://example.com"}, + query_params={"foo": "bar"}, + url="http://testserver/_dash-ws-callback", + client=SimpleNamespace(host="10.0.0.1"), + ) + + @contextmanager + def activate_request(): + token = set_current_request(handshake) + try: + yield FastAPIRequestAdapter() + finally: + reset_current_request(token) + + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + + def worker(): + with activate_request() as request_adapter: + return create_ws_context( + payload, + response_adapter=None, + websocket_callback=None, + request_adapter=request_adapter, + ) + + with ThreadPoolExecutor(max_workers=1) as executor: + g = executor.submit(worker).result() + + assert g.cookies == {"kcIdToken": "token-value"} + assert g.headers == {"origin": "https://example.com"} + assert g.args == {"foo": "bar"} + assert g.path == "http://testserver/_dash-ws-callback" + assert g.remote == "10.0.0.1" + assert g.origin == "https://example.com" diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py index 3d40493ba5..f423f6a637 100644 --- a/tests/websocket/test_ws_quart.py +++ b/tests/websocket/test_ws_quart.py @@ -226,3 +226,37 @@ def multi_output(n_clicks): dash_duo.wait_for_text_to_equal("#output3", "Third: 3") assert dash_duo.get_logs() == [] + + +def test_wsq007_websocket_request_context_cookies_quart(dash_duo): + """WebSocket callbacks should expose request cookies/headers on ctx (Quart). + + End-to-end regression test for https://github.com/plotly/dash/issues/3814: + request metadata from the WebSocket handshake (cookies, headers) must be + available on ``callback_context`` for ``websocket=True`` callbacks, just as + it is for HTTP callbacks. + """ + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def show_context(value): + return f"cookie={ctx.cookies.get('wscookie', '')} headers={bool(ctx.headers)}" + + dash_duo.start_server(app) + + # Set a cookie, then reload so the WebSocket handshake carries it. + dash_duo.driver.add_cookie({"name": "wscookie", "value": "wsval"}) + dash_duo.driver.refresh() + + dash_duo.wait_for_text_to_equal("#ws-output", "cookie=wsval headers=True") + + assert dash_duo.get_logs() == []