From 380d09d1ae5f2fc0dc076ddff8b50e620f853482 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Thu, 21 May 2026 09:21:44 +0200 Subject: [PATCH 1/3] Fix race condition langchain turnstile --- python/restate/ext/langchain/_state.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/restate/ext/langchain/_state.py b/python/restate/ext/langchain/_state.py index da2ad3d..b235500 100644 --- a/python/restate/ext/langchain/_state.py +++ b/python/restate/ext/langchain/_state.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# Copyright (c) 2023-2026 - Restate Software, Inc., Restate GmbH # # This file is part of the Restate SDK for Python, # which is released under the MIT license. @@ -21,8 +21,15 @@ def __init__(self) -> None: self.turnstile: Turnstile = Turnstile([]) -_state_var: ContextVar[_State] = ContextVar("restate_langchain_state", default=_State()) +# No default: a module-level default `_State()` would be a single object shared +# across every asyncio task that never calls `_state_var.set(...)`. +_state_var: ContextVar[_State] = ContextVar("restate_langchain_state") def current_state() -> _State: - return _state_var.get() + try: + return _state_var.get() + except LookupError: + state = _State() + _state_var.set(state) + return state From 6df68bbcc4bc2f06a2d8eb097ffd868a60548f53 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Thu, 21 May 2026 12:36:17 +0200 Subject: [PATCH 2/3] Fix race condition langchain turnstile --- python/restate/ext/langchain/_middleware.py | 21 ++++++---- python/restate/ext/langchain/_state.py | 45 ++++++++++++++++----- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/python/restate/ext/langchain/_middleware.py b/python/restate/ext/langchain/_middleware.py index 14eeb88..10b24ec 100644 --- a/python/restate/ext/langchain/_middleware.py +++ b/python/restate/ext/langchain/_middleware.py @@ -35,7 +35,7 @@ from restate.extensions import current_context from restate.ext.turnstile import Turnstile -from ._state import current_state +from ._state import get_or_create_state, state_from_ctx ToolCallResult = ToolMessage | Command @@ -92,12 +92,16 @@ async def call_model() -> SerializableModelResponse: if structured_response is not None and isinstance(schema, type) and issubclass(schema, BaseModel): structured_response = schema.model_validate(structured_response) - # Force tools to run sequentially by setting a turnstile. - # Avoids asyncio.gather() from running in parallel. + # Install this turn's turnstile on ctx.extension_data so sibling + # ``awrap_tool_call`` tasks (spawned by ``tool_node``'s gather) reach + # the same object via the shared Restate ``Context`` — independent of + # ContextVar inheritance. Mirrors how ``restate.ext.adk`` stores its + # PluginState turnstile. ai_message = next((m for m in journaled.result if isinstance(m, AIMessage)), None) - if ai_message: + state = get_or_create_state(ctx) + if ai_message is not None: tool_call_ids = [tid for tc in (ai_message.tool_calls or []) if (tid := tc.get("id")) is not None] - current_state().turnstile = Turnstile(tool_call_ids) + state.turnstile = Turnstile(tool_call_ids) # Turn into ModelResponse as expected by the agent return ModelResponse( @@ -115,8 +119,11 @@ async def awrap_tool_call( if tool_call_id is None: return await handler(request) - # Wait for turn and then execute - turnstile = current_state().turnstile + ctx = current_context() + state = state_from_ctx(ctx) + assert state is not None, "RestateMiddleware must run inside a Restate handler" + turnstile = state.turnstile + try: await turnstile.wait_for(tool_call_id) result = await handler(request) diff --git a/python/restate/ext/langchain/_state.py b/python/restate/ext/langchain/_state.py index b235500..9b680d9 100644 --- a/python/restate/ext/langchain/_state.py +++ b/python/restate/ext/langchain/_state.py @@ -9,27 +9,50 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # -from contextvars import ContextVar +from typing import Optional +from restate import Context from restate.ext.turnstile import Turnstile +from restate.server_context import get_extension_data, set_extension_data + + +def _extension_key(invocation_id: str) -> str: + return "restate_langchain_" + invocation_id class _State: + """Per-handler middleware state held on the Restate ``Context``. + + Stored in ``ctx.extension_data`` (mirrors ``restate.ext.adk``'s + ``PluginState``) so every asyncio task spawned during the handler + invocation reaches the same instance via the shared ``Context`` + object — independent of ``ContextVar`` inheritance. A ``ContextVar`` + binding set deep inside an ``awrap_model_call`` task didn't reach + the sibling ``awrap_tool_call`` tasks spawned by langgraph's + ``tool_node``; ``ctx.extension_data`` does. + + Holds the current turnstile — replaced on every model call to + describe that turn's batch of tool calls. ``aafter_agent`` resets + it so subsequent agent runs in the same handler start clean. + """ + __slots__ = ("turnstile",) def __init__(self) -> None: self.turnstile: Turnstile = Turnstile([]) - -# No default: a module-level default `_State()` would be a single object shared -# across every asyncio task that never calls `_state_var.set(...)`. -_state_var: ContextVar[_State] = ContextVar("restate_langchain_state") + def __close__(self) -> None: + # Called at handler end via auto_close_extension_data. + self.turnstile.cancel_all() -def current_state() -> _State: - try: - return _state_var.get() - except LookupError: +def get_or_create_state(ctx: Context) -> _State: + state: Optional[_State] = get_extension_data(ctx, "langchain-state") + if state is None: state = _State() - _state_var.set(state) - return state + set_extension_data(ctx, "langchain-state", state) + return state + + +def state_from_ctx(ctx: Context) -> Optional[_State]: + return get_extension_data(ctx, "langchain-state") From 798a1555a16c22ad12353aa22109a13fb4ae332d Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Thu, 21 May 2026 12:43:07 +0200 Subject: [PATCH 3/3] Fix race condition langchain turnstile --- python/restate/ext/langchain/_middleware.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/restate/ext/langchain/_middleware.py b/python/restate/ext/langchain/_middleware.py index 10b24ec..bc7fbf5 100644 --- a/python/restate/ext/langchain/_middleware.py +++ b/python/restate/ext/langchain/_middleware.py @@ -120,6 +120,7 @@ async def awrap_tool_call( return await handler(request) ctx = current_context() + assert ctx is not None, "RestateMiddleware must run inside a Restate handler" state = state_from_ctx(ctx) assert state is not None, "RestateMiddleware must run inside a Restate handler" turnstile = state.turnstile