diff --git a/python/restate/ext/langchain/_middleware.py b/python/restate/ext/langchain/_middleware.py index 14eeb88..bc7fbf5 100644 --- a/python/restate/ext/langchain/_middleware.py +++ b/python/restate/ext/langchain/_middleware.py @@ -35,7 +35,7 @@ from restate.extensions import current_context from restate.ext.turnstile import Turnstile -from ._state import current_state +from ._state import get_or_create_state, state_from_ctx ToolCallResult = ToolMessage | Command @@ -92,12 +92,16 @@ async def call_model() -> SerializableModelResponse: if structured_response is not None and isinstance(schema, type) and issubclass(schema, BaseModel): structured_response = schema.model_validate(structured_response) - # Force tools to run sequentially by setting a turnstile. - # Avoids asyncio.gather() from running in parallel. + # Install this turn's turnstile on ctx.extension_data so sibling + # ``awrap_tool_call`` tasks (spawned by ``tool_node``'s gather) reach + # the same object via the shared Restate ``Context`` — independent of + # ContextVar inheritance. Mirrors how ``restate.ext.adk`` stores its + # PluginState turnstile. ai_message = next((m for m in journaled.result if isinstance(m, AIMessage)), None) - if ai_message: + state = get_or_create_state(ctx) + if ai_message is not None: tool_call_ids = [tid for tc in (ai_message.tool_calls or []) if (tid := tc.get("id")) is not None] - current_state().turnstile = Turnstile(tool_call_ids) + state.turnstile = Turnstile(tool_call_ids) # Turn into ModelResponse as expected by the agent return ModelResponse( @@ -115,8 +119,12 @@ async def awrap_tool_call( if tool_call_id is None: return await handler(request) - # Wait for turn and then execute - turnstile = current_state().turnstile + ctx = current_context() + assert ctx is not None, "RestateMiddleware must run inside a Restate handler" + state = state_from_ctx(ctx) + assert state is not None, "RestateMiddleware must run inside a Restate handler" + turnstile = state.turnstile + try: await turnstile.wait_for(tool_call_id) result = await handler(request) diff --git a/python/restate/ext/langchain/_state.py b/python/restate/ext/langchain/_state.py index da2ad3d..9b680d9 100644 --- a/python/restate/ext/langchain/_state.py +++ b/python/restate/ext/langchain/_state.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# Copyright (c) 2023-2026 - Restate Software, Inc., Restate GmbH # # This file is part of the Restate SDK for Python, # which is released under the MIT license. @@ -9,20 +9,50 @@ # https://github.com/restatedev/sdk-typescript/blob/main/LICENSE # -from contextvars import ContextVar +from typing import Optional +from restate import Context from restate.ext.turnstile import Turnstile +from restate.server_context import get_extension_data, set_extension_data + + +def _extension_key(invocation_id: str) -> str: + return "restate_langchain_" + invocation_id class _State: + """Per-handler middleware state held on the Restate ``Context``. + + Stored in ``ctx.extension_data`` (mirrors ``restate.ext.adk``'s + ``PluginState``) so every asyncio task spawned during the handler + invocation reaches the same instance via the shared ``Context`` + object — independent of ``ContextVar`` inheritance. A ``ContextVar`` + binding set deep inside an ``awrap_model_call`` task didn't reach + the sibling ``awrap_tool_call`` tasks spawned by langgraph's + ``tool_node``; ``ctx.extension_data`` does. + + Holds the current turnstile — replaced on every model call to + describe that turn's batch of tool calls. ``aafter_agent`` resets + it so subsequent agent runs in the same handler start clean. + """ + __slots__ = ("turnstile",) def __init__(self) -> None: self.turnstile: Turnstile = Turnstile([]) + def __close__(self) -> None: + # Called at handler end via auto_close_extension_data. + self.turnstile.cancel_all() + -_state_var: ContextVar[_State] = ContextVar("restate_langchain_state", default=_State()) +def get_or_create_state(ctx: Context) -> _State: + state: Optional[_State] = get_extension_data(ctx, "langchain-state") + if state is None: + state = _State() + set_extension_data(ctx, "langchain-state", state) + return state -def current_state() -> _State: - return _state_var.get() +def state_from_ctx(ctx: Context) -> Optional[_State]: + return get_extension_data(ctx, "langchain-state")