Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions python/restate/ext/langchain/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from restate.extensions import current_context
from restate.ext.turnstile import Turnstile

from ._state import current_state
from ._state import get_or_create_state, state_from_ctx

ToolCallResult = ToolMessage | Command

Expand Down Expand Up @@ -92,12 +92,16 @@ async def call_model() -> SerializableModelResponse:
if structured_response is not None and isinstance(schema, type) and issubclass(schema, BaseModel):
structured_response = schema.model_validate(structured_response)

# Force tools to run sequentially by setting a turnstile.
# Avoids asyncio.gather() from running in parallel.
# Install this turn's turnstile on ctx.extension_data so sibling
# ``awrap_tool_call`` tasks (spawned by ``tool_node``'s gather) reach
# the same object via the shared Restate ``Context`` — independent of
# ContextVar inheritance. Mirrors how ``restate.ext.adk`` stores its
# PluginState turnstile.
ai_message = next((m for m in journaled.result if isinstance(m, AIMessage)), None)
if ai_message:
state = get_or_create_state(ctx)
if ai_message is not None:
tool_call_ids = [tid for tc in (ai_message.tool_calls or []) if (tid := tc.get("id")) is not None]
current_state().turnstile = Turnstile(tool_call_ids)
state.turnstile = Turnstile(tool_call_ids)

# Turn into ModelResponse as expected by the agent
return ModelResponse(
Expand All @@ -115,8 +119,12 @@ async def awrap_tool_call(
if tool_call_id is None:
return await handler(request)

# Wait for turn and then execute
turnstile = current_state().turnstile
ctx = current_context()
assert ctx is not None, "RestateMiddleware must run inside a Restate handler"
state = state_from_ctx(ctx)
assert state is not None, "RestateMiddleware must run inside a Restate handler"
turnstile = state.turnstile

try:
await turnstile.wait_for(tool_call_id)
result = await handler(request)
Expand Down
40 changes: 35 additions & 5 deletions python/restate/ext/langchain/_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
# Copyright (c) 2023-2026 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
Expand All @@ -9,20 +9,50 @@
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#

from contextvars import ContextVar
from typing import Optional

from restate import Context
from restate.ext.turnstile import Turnstile
from restate.server_context import get_extension_data, set_extension_data


def _extension_key(invocation_id: str) -> str:
return "restate_langchain_" + invocation_id


class _State:
"""Per-handler middleware state held on the Restate ``Context``.

Stored in ``ctx.extension_data`` (mirrors ``restate.ext.adk``'s
``PluginState``) so every asyncio task spawned during the handler
invocation reaches the same instance via the shared ``Context``
object — independent of ``ContextVar`` inheritance. A ``ContextVar``
binding set deep inside an ``awrap_model_call`` task didn't reach
the sibling ``awrap_tool_call`` tasks spawned by langgraph's
``tool_node``; ``ctx.extension_data`` does.

Holds the current turnstile — replaced on every model call to
describe that turn's batch of tool calls. ``aafter_agent`` resets
it so subsequent agent runs in the same handler start clean.
"""

__slots__ = ("turnstile",)

def __init__(self) -> None:
self.turnstile: Turnstile = Turnstile([])

def __close__(self) -> None:
# Called at handler end via auto_close_extension_data.
self.turnstile.cancel_all()


_state_var: ContextVar[_State] = ContextVar("restate_langchain_state", default=_State())
def get_or_create_state(ctx: Context) -> _State:
state: Optional[_State] = get_extension_data(ctx, "langchain-state")
if state is None:
state = _State()
set_extension_data(ctx, "langchain-state", state)
return state


def current_state() -> _State:
return _state_var.get()
def state_from_ctx(ctx: Context) -> Optional[_State]:
return get_extension_data(ctx, "langchain-state")
Loading