From 0b949ca77c3975acb91b5d6e133477acb031e0be Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Tue, 21 Apr 2026 15:57:49 +0800 Subject: [PATCH 1/3] feat: try to support langchain-ai/agent-protocol Signed-off-by: Chojan Shang --- contrib/bubseek-langchain/README.md | 12 ++ contrib/bubseek-langchain/pyproject.toml | 1 + .../src/bubseek_langchain/__init__.py | 6 +- .../src/bubseek_langchain/agent_protocol.py | 176 ++++++++++++++++++ .../src/bubseek_langchain/config.py | 26 +++ .../src/bubseek_langchain/normalize.py | 4 + .../tests/test_agent_protocol.py | 173 +++++++++++++++++ examples/langchain/README.md | 37 ++++ examples/langchain/remote_agent_protocol.py | 21 +++ uv.lock | 2 + 10 files changed, 457 insertions(+), 1 deletion(-) create mode 100644 contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py create mode 100644 contrib/bubseek-langchain/tests/test_agent_protocol.py create mode 100644 examples/langchain/remote_agent_protocol.py diff --git a/contrib/bubseek-langchain/README.md b/contrib/bubseek-langchain/README.md index c3eb757..b1d5eb1 100644 --- a/contrib/bubseek-langchain/README.md +++ b/contrib/bubseek-langchain/README.md @@ -6,6 +6,7 @@ Current scope: - only `Runnable` mode is supported; - Bub tools can be bridged into LangChain tools; +- remote agent-protocol services can be wrapped through `langgraph-sdk`; - Bub tape recording still works for user / assistant turns and tool spans; - prompts starting with `,` still fall through to Bub built-in internal commands. @@ -58,6 +59,7 @@ Repository examples live under [examples/langchain/README.md](/home/shangzhuoran - [examples/langchain/minimal_runnable.py](/home/shangzhuoran.szr/oceanbase/bubseek/examples/langchain/minimal_runnable.py) - [examples/langchain/deepagents_dashscope.py](/home/shangzhuoran.szr/oceanbase/bubseek/examples/langchain/deepagents_dashscope.py) +- [examples/langchain/remote_agent_protocol.py](/home/shangzhuoran.szr/oceanbase/bubseek/examples/langchain/remote_agent_protocol.py) Typical minimal run: @@ -66,3 +68,13 @@ export BUB_LANGCHAIN_MODE=runnable export BUB_LANGCHAIN_FACTORY=examples.langchain.minimal_runnable:minimal_lc_agent uv run bub run "Summarize this workspace in one sentence." ``` + +Typical remote agent-protocol run: + +```bash +export BUB_LANGCHAIN_MODE=runnable +export BUB_LANGCHAIN_FACTORY=examples.langchain.remote_agent_protocol:remote_agent_protocol_agent +export BUB_AGENT_PROTOCOL_URL=http://localhost:2024 +export BUB_AGENT_PROTOCOL_AGENT_ID=agent +uv run bub chat +``` diff --git a/contrib/bubseek-langchain/pyproject.toml b/contrib/bubseek-langchain/pyproject.toml index 690c75a..b553637 100644 --- a/contrib/bubseek-langchain/pyproject.toml +++ b/contrib/bubseek-langchain/pyproject.toml @@ -7,6 +7,7 @@ authors = [{ name = "Chojan Shang", email = "psiace@apache.org" }] requires-python = ">=3.12" dependencies = [ "bub", + "langgraph-sdk>=0.3.13", "langchain-core>=0.3.0", "pydantic>=2.0", "pydantic-settings>=2.0.0", diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/__init__.py b/contrib/bubseek-langchain/src/bubseek_langchain/__init__.py index 2d83918..f1a3257 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/__init__.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/__init__.py @@ -1,17 +1,21 @@ """LangChain Runnable adapter for Bubseek.""" +from .agent_protocol import AgentProtocolRunnable from .bridge import LangchainFactoryRequest, LangchainRunContext, RunnableBinding -from .config import LangchainPluginSettings, load_settings +from .config import AgentProtocolSettings, LangchainPluginSettings, load_agent_protocol_settings, load_settings from .errors import LangchainConfigError from .plugin import LangchainPlugin, main __all__ = [ + "AgentProtocolRunnable", + "AgentProtocolSettings", "LangchainConfigError", "LangchainFactoryRequest", "LangchainPlugin", "LangchainPluginSettings", "LangchainRunContext", "RunnableBinding", + "load_agent_protocol_settings", "load_settings", "main", ] diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py new file mode 100644 index 0000000..40524fe --- /dev/null +++ b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import asyncio +import hashlib +from collections.abc import AsyncIterator, Mapping +from typing import Any + +from langchain_core.runnables import Runnable +from loguru import logger + +from .bridge import LangchainRunContext, extract_prompt_text +from .config import AgentProtocolSettings +from .normalize import normalize_langchain_output + + +def _bind_logger(run_context: LangchainRunContext | None): + if run_context is None: + return logger + return logger.bind(**run_context.as_logger_extra()) + + +def _message_role(message: Mapping[str, Any]) -> str | None: + for key in ("role", "type"): + value = message.get(key) + if isinstance(value, str) and value.strip(): + return value.strip().lower() + return None + + +def _is_assistant_message(message: Mapping[str, Any]) -> bool: + role = _message_role(message) + if role is None: + return True + return role in {"assistant", "ai", "aimessage", "aimessagechunk"} + + +def _messages_from_stream_part(part: Any) -> list[Mapping[str, Any]]: + if hasattr(part, "event") and hasattr(part, "data"): + event = part.event + data = part.data + if event == "messages" and isinstance(data, list) and data: + first = data[0] + return [first] if isinstance(first, Mapping) else [] + if event in {"messages/partial", "messages/complete"} and isinstance(data, list): + return [item for item in data if isinstance(item, Mapping)] + return [] + + if not isinstance(part, dict): + return [] + part_type = part.get("type") or part.get("event") + data = part.get("data") + if part_type == "messages" and isinstance(data, list) and data: + first = data[0] + return [first] if isinstance(first, Mapping) else [] + if part_type in {"messages/partial", "messages/complete"} and isinstance(data, list): + return [item for item in data if isinstance(item, Mapping)] + return [] + + +def _final_state_from_stream_part(part: Any) -> Any | None: + if hasattr(part, "event") and hasattr(part, "data"): + return part.data if part.event == "values" else None + if isinstance(part, dict) and (part.get("type") == "values" or part.get("event") == "values"): + return part.get("data") + return None + + +class AgentProtocolRunnable(Runnable[Any, Any]): + """Wrap a remote agent-protocol server as a LangChain-compatible Runnable.""" + + def __init__( + self, + *, + settings: AgentProtocolSettings, + session_id: str | None, + langchain_context: LangchainRunContext | None = None, + ) -> None: + self._settings = settings + self._session_id = session_id + self._langchain_context = langchain_context + self._logger = _bind_logger(langchain_context) + self._client: Any | None = None + + def invoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> Any: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.ainvoke(value, config=config, **kwargs)) + raise RuntimeError("AgentProtocolRunnable.invoke cannot be used from a running event loop; use ainvoke instead") + + async def ainvoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> Any: + thread_id = await self._resolve_thread_id() + run_input = self._build_run_input(value) + metadata = self._build_metadata() + self._logger.debug( + "Invoking remote agent-protocol agent={} stateful={} thread_id={}", + self._settings.agent_id, + self._settings.stateful, + thread_id, + ) + return await self._client_instance().runs.wait( + thread_id=thread_id, + assistant_id=self._settings.agent_id, + input=run_input, + metadata=metadata, + if_not_exists="create" if thread_id is not None else None, + ) + + async def astream(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> AsyncIterator[str]: + thread_id = await self._resolve_thread_id() + run_input = self._build_run_input(value) + metadata = self._build_metadata() + self._logger.debug( + "Streaming remote agent-protocol agent={} stateful={} thread_id={}", + self._settings.agent_id, + self._settings.stateful, + thread_id, + ) + emitted = False + final_state: Any | None = None + async for part in self._client_instance().runs.stream( + thread_id=thread_id, + assistant_id=self._settings.agent_id, + input=run_input, + metadata=metadata, + if_not_exists="create" if thread_id is not None else None, + stream_mode=["messages", "values"], + ): + maybe_final_state = _final_state_from_stream_part(part) + if maybe_final_state is not None: + final_state = maybe_final_state + + for message in _messages_from_stream_part(part): + if not _is_assistant_message(message): + continue + text = normalize_langchain_output(dict(message)) + if not text: + continue + emitted = True + yield text + + if not emitted and final_state is not None: + fallback_text = normalize_langchain_output(final_state) + if fallback_text: + yield fallback_text + + def _client_instance(self) -> Any: + if self._client is None: + from langgraph_sdk import get_client + + client_kwargs: dict[str, Any] = {"url": self._settings.url} + if self._settings.api_key is not None: + client_kwargs["api_key"] = self._settings.api_key + self._client = get_client(**client_kwargs) + return self._client + + def _build_run_input(self, value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + prompt_text = extract_prompt_text(value) if isinstance(value, str | list) else normalize_langchain_output(value) + return {"messages": [{"role": "user", "content": prompt_text}]} + + def _build_metadata(self) -> dict[str, str]: + if self._langchain_context is None: + return {} + return self._langchain_context.as_metadata() + + async def _resolve_thread_id(self) -> str | None: + if not self._settings.stateful or not self._session_id: + return None + return self._default_thread_id() + + def _default_thread_id(self) -> str: + payload = f"{self._settings.url}\0{self._settings.agent_id}\0{self._session_id}" + digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:24] + return f"bubseek-{digest}" diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/config.py b/contrib/bubseek-langchain/src/bubseek_langchain/config.py index 2e1b8d3..79df108 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/config.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/config.py @@ -2,6 +2,7 @@ from typing import Literal +from pydantic import AliasChoices, Field, ValidationError from pydantic_settings import BaseSettings, SettingsConfigDict from .errors import LangchainConfigError @@ -22,10 +23,35 @@ class LangchainPluginSettings(BaseSettings): tape: bool = True +class AgentProtocolSettings(BaseSettings): + """Configuration for the remote agent-protocol runnable adapter.""" + + model_config = SettingsConfigDict( + env_file=".env", + extra="ignore", + populate_by_name=True, + ) + + url: str = Field(validation_alias="BUB_AGENT_PROTOCOL_URL") + agent_id: str = Field(validation_alias="BUB_AGENT_PROTOCOL_AGENT_ID") + api_key: str | None = Field( + default=None, + validation_alias=AliasChoices("BUB_AGENT_PROTOCOL_API_KEY", "BUB_API_KEY"), + ) + stateful: bool = Field(default=True, validation_alias="BUB_AGENT_PROTOCOL_STATEFUL") + + def load_settings() -> LangchainPluginSettings: return LangchainPluginSettings() +def load_agent_protocol_settings() -> AgentProtocolSettings: + try: + return AgentProtocolSettings() + except ValidationError as exc: + raise LangchainConfigError(str(exc)) from exc + + def is_enabled(settings: LangchainPluginSettings) -> bool: return settings.mode == "runnable" diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/normalize.py b/contrib/bubseek-langchain/src/bubseek_langchain/normalize.py index b1465a4..4de1f80 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/normalize.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/normalize.py @@ -45,8 +45,12 @@ def _dict_to_str(data: dict[str, Any]) -> str: if isinstance(data.get("messages"), list): messages = data["messages"] if not messages: + if "values" in data: + return normalize_langchain_output(data["values"]) return "" return normalize_langchain_output(messages[-1]) + if "values" in data: + return normalize_langchain_output(data["values"]) return json.dumps(data, ensure_ascii=False, default=str) diff --git a/contrib/bubseek-langchain/tests/test_agent_protocol.py b/contrib/bubseek-langchain/tests/test_agent_protocol.py new file mode 100644 index 0000000..1d1b1ac --- /dev/null +++ b/contrib/bubseek-langchain/tests/test_agent_protocol.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import pytest +from bubseek_langchain.agent_protocol import AgentProtocolRunnable, AgentProtocolSettings +from bubseek_langchain.bridge import LangchainRunContext +from langchain_core.runnables import Runnable + + +class _FakeRunsClient: + def __init__(self, *, wait_response: Any, stream_parts: list[Any]) -> None: + self.wait_calls: list[dict[str, Any]] = [] + self.stream_calls: list[dict[str, Any]] = [] + self._wait_response = wait_response + self._stream_parts = list(stream_parts) + + async def wait(self, **kwargs: Any) -> Any: + self.wait_calls.append(kwargs) + return self._wait_response + + async def stream(self, **kwargs: Any): + self.stream_calls.append(kwargs) + for part in self._stream_parts: + yield part + + +class _FakeClient: + def __init__(self, *, wait_response: Any, stream_parts: list[Any]) -> None: + self.runs = _FakeRunsClient(wait_response=wait_response, stream_parts=stream_parts) + + +def _run_context() -> LangchainRunContext: + return LangchainRunContext( + session_id="session-1", + tape_name="tape-x", + run_id="langchain-run-1", + ) + + +def test_agent_protocol_runnable_is_langchain_compatible() -> None: + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://example.com", agent_id="agent"), + session_id="session-1", + langchain_context=_run_context(), + ) + + assert isinstance(runnable, Runnable) + assert callable(runnable.invoke) + assert callable(runnable.ainvoke) + assert callable(runnable.astream) + + +def test_ainvoke_uses_deterministic_thread_id_for_stateful_sessions() -> None: + fake_client = _FakeClient( + wait_response={"messages": [{"role": "assistant", "content": "remote answer"}]}, + stream_parts=[], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=True), + session_id="session-1", + langchain_context=_run_context(), + ) + runnable._client = fake_client + + first = asyncio.run(runnable.ainvoke("hello")) + second = asyncio.run(runnable.ainvoke("again")) + + assert first["messages"][-1]["content"] == "remote answer" + assert second["messages"][-1]["content"] == "remote answer" + assert len(fake_client.runs.wait_calls) == 2 + first_call = fake_client.runs.wait_calls[0] + second_call = fake_client.runs.wait_calls[1] + assert first_call["assistant_id"] == "agent" + assert first_call["input"] == {"messages": [{"role": "user", "content": "hello"}]} + assert first_call["if_not_exists"] == "create" + assert first_call["thread_id"].startswith("bubseek-") + assert second_call["thread_id"] == first_call["thread_id"] + + +def test_ainvoke_passes_dict_input_through() -> None: + fake_client = _FakeClient(wait_response={"ok": True}, stream_parts=[]) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + payload = {"messages": [{"role": "user", "content": "hi"}], "context": {"mode": "fast"}} + asyncio.run(runnable.ainvoke(payload)) + + assert fake_client.runs.wait_calls[0]["thread_id"] is None + assert fake_client.runs.wait_calls[0]["input"] == payload + assert fake_client.runs.wait_calls[0]["if_not_exists"] is None + + +def test_astream_yields_assistant_message_chunks() -> None: + fake_client = _FakeClient( + wait_response=None, + stream_parts=[ + {"event": "messages/partial", "data": [{"type": "human", "content": "hello"}]}, + {"event": "messages/partial", "data": [{"type": "ai", "content": "Hel"}]}, + {"event": "messages/partial", "data": [{"type": "ai", "content": "lo"}]}, + {"event": "values", "data": {"messages": [{"type": "ai", "content": "Hello"}]}}, + ], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + async def _collect() -> list[str]: + return [chunk async for chunk in runnable.astream("hello")] + + chunks = asyncio.run(_collect()) + + assert chunks == ["Hel", "lo"] + assert fake_client.runs.stream_calls[0]["stream_mode"] == ["messages", "values"] + assert "version" not in fake_client.runs.stream_calls[0] + + +def test_astream_falls_back_to_final_state_when_no_message_chunks() -> None: + fake_client = _FakeClient( + wait_response=None, + stream_parts=[ + {"event": "values", "data": {"messages": [{"type": "ai", "content": "Final answer"}]}}, + ], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + async def _collect() -> list[str]: + return [chunk async for chunk in runnable.astream("hello")] + + chunks = asyncio.run(_collect()) + + assert chunks == ["Final answer"] + + +def test_remote_example_factory_uses_prompt_and_request_context( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + from bubseek_langchain.bridge import LangchainFactoryRequest + + from examples.langchain.remote_agent_protocol import remote_agent_protocol_agent + + monkeypatch.setenv("BUB_AGENT_PROTOCOL_URL", "http://remote") + monkeypatch.setenv("BUB_AGENT_PROTOCOL_AGENT_ID", "agent") + + request = LangchainFactoryRequest( + state={}, + session_id="session-1", + workspace=tmp_path, + tools=[], + system_prompt="system", + prompt=[{"type": "text", "text": "hello remote"}], + langchain_context=_run_context(), + ) + + binding = remote_agent_protocol_agent(request=request) + + assert binding.invoke_input == request.prompt + assert isinstance(binding.runnable, AgentProtocolRunnable) diff --git a/examples/langchain/README.md b/examples/langchain/README.md index 90b485f..588d4e4 100644 --- a/examples/langchain/README.md +++ b/examples/langchain/README.md @@ -84,3 +84,40 @@ def get_weather(city: str) -> str: ``` If `BUB_LANGCHAIN_INCLUDE_BUB_TOOLS=true`, the DeepAgents example also appends Bub-bridged tools to its tool list. + +## Remote Agent Protocol + +Factory path: + +```bash +examples.langchain.remote_agent_protocol:remote_agent_protocol_agent +``` + +Enable it: + +```bash +export BUB_LANGCHAIN_MODE=runnable +export BUB_LANGCHAIN_FACTORY=examples.langchain.remote_agent_protocol:remote_agent_protocol_agent +export BUB_AGENT_PROTOCOL_URL=http://localhost:2024 +export BUB_AGENT_PROTOCOL_AGENT_ID=agent +``` + +Optional override: + +```bash +export BUB_AGENT_PROTOCOL_API_KEY=your-api-key +export BUB_AGENT_PROTOCOL_STATEFUL=true +``` + +Run it: + +```bash +uv run bub chat +uv run bub run "Summarize this workspace in one sentence." +``` + +Notes: + +- `BUB_AGENT_PROTOCOL_STATEFUL=true` maps each Bub session to a deterministic protocol `thread_id`. +- The adapter uses `langgraph_sdk.get_client()` as transport, but only relies on the standard `agent_id`, `thread_id`, `messages`, `values`, and `stream_mode` subset. +- Remote tool execution remains owned by the server-side assistant; local Bub tools are not forwarded into the remote runtime. diff --git a/examples/langchain/remote_agent_protocol.py b/examples/langchain/remote_agent_protocol.py new file mode 100644 index 0000000..e58f0d8 --- /dev/null +++ b/examples/langchain/remote_agent_protocol.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from bubseek_langchain import AgentProtocolRunnable, RunnableBinding, load_agent_protocol_settings +from bubseek_langchain.bridge import LangchainFactoryRequest + + +def remote_agent_protocol_agent( + *, + request: LangchainFactoryRequest, +) -> RunnableBinding: + """Build a RunnableBinding backed by a remote agent-protocol server.""" + + runnable = AgentProtocolRunnable( + settings=load_agent_protocol_settings(), + session_id=request.session_id, + langchain_context=request.langchain_context, + ) + return RunnableBinding( + runnable=runnable, + invoke_input=request.prompt, + ) diff --git a/uv.lock b/uv.lock index 049355a..406fff0 100644 --- a/uv.lock +++ b/uv.lock @@ -495,6 +495,7 @@ source = { editable = "contrib/bubseek-langchain" } dependencies = [ { name = "bub" }, { name = "langchain-core" }, + { name = "langgraph-sdk" }, { name = "pydantic" }, { name = "pydantic-settings" }, ] @@ -511,6 +512,7 @@ requires-dist = [ { name = "deepagents", marker = "extra == 'deepagents'", specifier = ">=0.5.3" }, { name = "langchain-core", specifier = ">=0.3.0" }, { name = "langchain-openai", marker = "extra == 'deepagents'", specifier = ">=0.3.0" }, + { name = "langgraph-sdk", specifier = ">=0.3.13" }, { name = "pydantic", specifier = ">=2.0" }, { name = "pydantic-settings", specifier = ">=2.0.0" }, ] From c00979f36755a6379f5b654b118c7e2f591a7044 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Tue, 21 Apr 2026 17:21:35 +0800 Subject: [PATCH 2/3] chore: make agent protocol runable works Signed-off-by: Chojan Shang --- .../src/bubseek_langchain/agent_protocol.py | 148 ++++++++++++++---- .../tests/test_agent_protocol.py | 106 ++++++++++++- examples/langchain/remote_agent_protocol.py | 35 +++++ 3 files changed, 252 insertions(+), 37 deletions(-) diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py index 40524fe..bbe9848 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py @@ -2,6 +2,7 @@ import asyncio import hashlib +import json from collections.abc import AsyncIterator, Mapping from typing import Any @@ -12,6 +13,8 @@ from .config import AgentProtocolSettings from .normalize import normalize_langchain_output +INTERRUPT_KEY = "__interrupt__" + def _bind_logger(run_context: LangchainRunContext | None): if run_context is None: @@ -34,39 +37,95 @@ def _is_assistant_message(message: Mapping[str, Any]) -> bool: return role in {"assistant", "ai", "aimessage", "aimessagechunk"} -def _messages_from_stream_part(part: Any) -> list[Mapping[str, Any]]: - if hasattr(part, "event") and hasattr(part, "data"): +def _stream_event_name(part: Any) -> str | None: + if hasattr(part, "event"): event = part.event - data = part.data - if event == "messages" and isinstance(data, list) and data: - first = data[0] - return [first] if isinstance(first, Mapping) else [] - if event in {"messages/partial", "messages/complete"} and isinstance(data, list): - return [item for item in data if isinstance(item, Mapping)] - return [] - - if not isinstance(part, dict): - return [] - part_type = part.get("type") or part.get("event") - data = part.get("data") - if part_type == "messages" and isinstance(data, list) and data: - first = data[0] - return [first] if isinstance(first, Mapping) else [] - if part_type in {"messages/partial", "messages/complete"} and isinstance(data, list): - return [item for item in data if isinstance(item, Mapping)] - return [] + return event if isinstance(event, str) else None + if isinstance(part, dict): + event = part.get("event") or part.get("type") + return event if isinstance(event, str) else None + return None -def _final_state_from_stream_part(part: Any) -> Any | None: - if hasattr(part, "event") and hasattr(part, "data"): - return part.data if part.event == "values" else None - if isinstance(part, dict) and (part.get("type") == "values" or part.get("event") == "values"): +def _stream_event_data(part: Any) -> Any: + if hasattr(part, "data"): + return part.data + if isinstance(part, dict): return part.get("data") return None +def _interrupts_from_stream_part(part: Any) -> list[Any]: + if isinstance(part, dict): + interrupts = part.get("interrupts") + if isinstance(interrupts, list) and interrupts: + return interrupts + + data = _stream_event_data(part) + if isinstance(data, Mapping): + interrupts = data.get(INTERRUPT_KEY) + if isinstance(interrupts, list) and interrupts: + return interrupts + return [] + + +def _raise_for_stream_part(part: Any) -> None: + event = _stream_event_name(part) + if event is None: + return + + if event.startswith("error"): + detail = normalize_langchain_output(_stream_event_data(part)) + message = detail or "Remote agent-protocol run failed" + raise AgentProtocolRemoteError(message) + + interrupts = _interrupts_from_stream_part(part) + if interrupts and (event == "values" or event.startswith("updates")): + raise AgentProtocolInterruptedError( + f"Remote agent-protocol run interrupted: {json.dumps(interrupts, ensure_ascii=False, default=str)}" + ) + + +def _messages_from_stream_part(part: Any) -> tuple[str | None, list[Mapping[str, Any]]]: + event = _stream_event_name(part) + data = _stream_event_data(part) + if event is None: + return None, [] + + if event == "messages" and isinstance(data, list) and data: + first = data[0] + if isinstance(first, Mapping): + return event, [first] + return event, [] + + if event in {"messages/partial", "messages/complete"} and isinstance(data, list): + return event, [item for item in data if isinstance(item, Mapping)] + + return event, [] + + +def _text_from_message(message: Mapping[str, Any]) -> str: + return normalize_langchain_output(dict(message)) + + +def _final_state_from_stream_part(part: Any) -> Any | None: + return _stream_event_data(part) if _stream_event_name(part) == "values" else None + + +class AgentProtocolRemoteError(RuntimeError): + """Raised when the remote agent-protocol run returns an explicit error event.""" + + +class AgentProtocolInterruptedError(RuntimeError): + """Raised when the remote agent-protocol run reports an interrupt.""" + + class AgentProtocolRunnable(Runnable[Any, Any]): - """Wrap a remote agent-protocol server as a LangChain-compatible Runnable.""" + """Wrap a remote Bub agent-protocol endpoint as a Bub-oriented Runnable. + + This adapter intentionally accepts Bub prompt shapes or a fully-formed input + dict. It does not implement general Pregel or RemoteGraph config semantics. + """ def __init__( self, @@ -91,7 +150,7 @@ def invoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any async def ainvoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> Any: thread_id = await self._resolve_thread_id() run_input = self._build_run_input(value) - metadata = self._build_metadata() + metadata = self._build_metadata(config) self._logger.debug( "Invoking remote agent-protocol agent={} stateful={} thread_id={}", self._settings.agent_id, @@ -109,7 +168,7 @@ async def ainvoke(self, value: Any, config: dict[str, Any] | None = None, **kwar async def astream(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> AsyncIterator[str]: thread_id = await self._resolve_thread_id() run_input = self._build_run_input(value) - metadata = self._build_metadata() + metadata = self._build_metadata(config) self._logger.debug( "Streaming remote agent-protocol agent={} stateful={} thread_id={}", self._settings.agent_id, @@ -117,6 +176,7 @@ async def astream(self, value: Any, config: dict[str, Any] | None = None, **kwar thread_id, ) emitted = False + saw_partial_message = False final_state: Any | None = None async for part in self._client_instance().runs.stream( thread_id=thread_id, @@ -124,16 +184,24 @@ async def astream(self, value: Any, config: dict[str, Any] | None = None, **kwar input=run_input, metadata=metadata, if_not_exists="create" if thread_id is not None else None, - stream_mode=["messages", "values"], + stream_mode=["messages", "values", "updates"], ): + _raise_for_stream_part(part) + maybe_final_state = _final_state_from_stream_part(part) if maybe_final_state is not None: final_state = maybe_final_state - for message in _messages_from_stream_part(part): + event, messages = _messages_from_stream_part(part) + if event == "messages/partial": + saw_partial_message = True + if event == "messages/complete" and saw_partial_message: + continue + + for message in messages: if not _is_assistant_message(message): continue - text = normalize_langchain_output(dict(message)) + text = _text_from_message(message) if not text: continue emitted = True @@ -160,10 +228,22 @@ def _build_run_input(self, value: Any) -> dict[str, Any]: prompt_text = extract_prompt_text(value) if isinstance(value, str | list) else normalize_langchain_output(value) return {"messages": [{"role": "user", "content": prompt_text}]} - def _build_metadata(self) -> dict[str, str]: - if self._langchain_context is None: - return {} - return self._langchain_context.as_metadata() + def _build_metadata(self, config: Mapping[str, Any] | None) -> dict[str, Any]: + metadata: dict[str, Any] = {} + if self._langchain_context is not None: + metadata.update(self._langchain_context.as_metadata()) + + if not isinstance(config, Mapping): + return metadata + + config_metadata = config.get("metadata") + if not isinstance(config_metadata, Mapping): + return metadata + + for key, value in config_metadata.items(): + if isinstance(key, str): + metadata[key] = value + return metadata async def _resolve_thread_id(self) -> str | None: if not self._settings.stateful or not self._session_id: diff --git a/contrib/bubseek-langchain/tests/test_agent_protocol.py b/contrib/bubseek-langchain/tests/test_agent_protocol.py index 1d1b1ac..7af8f4d 100644 --- a/contrib/bubseek-langchain/tests/test_agent_protocol.py +++ b/contrib/bubseek-langchain/tests/test_agent_protocol.py @@ -5,7 +5,12 @@ from typing import Any import pytest -from bubseek_langchain.agent_protocol import AgentProtocolRunnable, AgentProtocolSettings +from bubseek_langchain.agent_protocol import ( + AgentProtocolInterruptedError, + AgentProtocolRemoteError, + AgentProtocolRunnable, + AgentProtocolSettings, +) from bubseek_langchain.bridge import LangchainRunContext from langchain_core.runnables import Runnable @@ -97,6 +102,25 @@ def test_ainvoke_passes_dict_input_through() -> None: assert fake_client.runs.wait_calls[0]["if_not_exists"] is None +def test_ainvoke_merges_config_metadata() -> None: + fake_client = _FakeClient(wait_response={"ok": True}, stream_parts=[]) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + asyncio.run(runnable.ainvoke("hello", config={"metadata": {"source": "test"}})) + + assert fake_client.runs.wait_calls[0]["metadata"] == { + "session_id": "session-1", + "langchain_run_id": "langchain-run-1", + "tape_name": "tape-x", + "source": "test", + } + + def test_astream_yields_assistant_message_chunks() -> None: fake_client = _FakeClient( wait_response=None, @@ -120,10 +144,32 @@ async def _collect() -> list[str]: chunks = asyncio.run(_collect()) assert chunks == ["Hel", "lo"] - assert fake_client.runs.stream_calls[0]["stream_mode"] == ["messages", "values"] + assert fake_client.runs.stream_calls[0]["stream_mode"] == ["messages", "values", "updates"] assert "version" not in fake_client.runs.stream_calls[0] +def test_astream_does_not_duplicate_complete_message_after_partials() -> None: + fake_client = _FakeClient( + wait_response=None, + stream_parts=[ + {"event": "messages/partial", "data": [{"type": "ai", "content": "Hel"}]}, + {"event": "messages/partial", "data": [{"type": "ai", "content": "lo"}]}, + {"event": "messages/complete", "data": [{"type": "ai", "content": "Hello"}]}, + ], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + async def _collect() -> list[str]: + return [chunk async for chunk in runnable.astream("hello")] + + assert asyncio.run(_collect()) == ["Hel", "lo"] + + def test_astream_falls_back_to_final_state_when_no_message_chunks() -> None: fake_client = _FakeClient( wait_response=None, @@ -146,13 +192,55 @@ async def _collect() -> list[str]: assert chunks == ["Final answer"] +def test_astream_raises_on_remote_error_event() -> None: + fake_client = _FakeClient( + wait_response=None, + stream_parts=[ + {"event": "error", "data": {"message": "boom"}}, + ], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + async def _collect() -> list[str]: + return [chunk async for chunk in runnable.astream("hello")] + + with pytest.raises(AgentProtocolRemoteError, match="boom"): + asyncio.run(_collect()) + + +def test_astream_raises_on_interrupt_update_event() -> None: + fake_client = _FakeClient( + wait_response=None, + stream_parts=[ + {"event": "updates", "data": {"__interrupt__": [{"value": "wait"}]}}, + ], + ) + runnable = AgentProtocolRunnable( + settings=AgentProtocolSettings(url="http://remote", agent_id="agent", stateful=False), + session_id=None, + langchain_context=_run_context(), + ) + runnable._client = fake_client + + async def _collect() -> list[str]: + return [chunk async for chunk in runnable.astream("hello")] + + with pytest.raises(AgentProtocolInterruptedError, match="interrupted"): + asyncio.run(_collect()) + + def test_remote_example_factory_uses_prompt_and_request_context( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: from bubseek_langchain.bridge import LangchainFactoryRequest - from examples.langchain.remote_agent_protocol import remote_agent_protocol_agent + from examples.langchain.remote_agent_protocol import _parse_remote_agent_output, remote_agent_protocol_agent monkeypatch.setenv("BUB_AGENT_PROTOCOL_URL", "http://remote") monkeypatch.setenv("BUB_AGENT_PROTOCOL_AGENT_ID", "agent") @@ -171,3 +259,15 @@ def test_remote_example_factory_uses_prompt_and_request_context( assert binding.invoke_input == request.prompt assert isinstance(binding.runnable, AgentProtocolRunnable) + assert binding.output_parser is _parse_remote_agent_output + + +def test_remote_output_parser_extracts_visible_text_blocks() -> None: + from examples.langchain.remote_agent_protocol import _parse_remote_agent_output + + payload = ( + '[{"signature":"","thinking":"internal","type":"thinking"},' + '{"text":"Visible answer","type":"text"}]' + ) + + assert _parse_remote_agent_output(payload) == "Visible answer" diff --git a/examples/langchain/remote_agent_protocol.py b/examples/langchain/remote_agent_protocol.py index e58f0d8..a98ea3b 100644 --- a/examples/langchain/remote_agent_protocol.py +++ b/examples/langchain/remote_agent_protocol.py @@ -1,7 +1,41 @@ from __future__ import annotations +import json +from typing import Any + from bubseek_langchain import AgentProtocolRunnable, RunnableBinding, load_agent_protocol_settings from bubseek_langchain.bridge import LangchainFactoryRequest +from bubseek_langchain.normalize import normalize_langchain_output + + +def _extract_visible_text_blocks(payload: Any) -> str: + if isinstance(payload, dict): + text = payload.get("text") + return text if isinstance(text, str) else "" + if not isinstance(payload, list): + return "" + + parts = [ + text + for item in payload + if isinstance(item, dict) and isinstance((text := item.get("text")), str) and text.strip() + ] + return "\n".join(parts) + + +def _parse_remote_agent_output(value: Any) -> str: + text = normalize_langchain_output(value) + stripped = text.strip() + if not stripped or stripped[0] not in "[{": + return text + + try: + payload = json.loads(stripped) + except json.JSONDecodeError: + return text + + visible_text = _extract_visible_text_blocks(payload) + return visible_text or text def remote_agent_protocol_agent( @@ -18,4 +52,5 @@ def remote_agent_protocol_agent( return RunnableBinding( runnable=runnable, invoke_input=request.prompt, + output_parser=_parse_remote_agent_output, ) From 6b5487250d8c0c6542a73269468417d8a84886c4 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Tue, 21 Apr 2026 17:39:04 +0800 Subject: [PATCH 3/3] chore: make check happy Signed-off-by: Chojan Shang --- .../src/bubseek_langchain/agent_protocol.py | 19 ++++++++++++------- .../src/bubseek_langchain/config.py | 5 +++-- .../tests/test_agent_protocol.py | 5 +---- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py index bbe9848..0de398f 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/agent_protocol.py @@ -6,7 +6,7 @@ from collections.abc import AsyncIterator, Mapping from typing import Any -from langchain_core.runnables import Runnable +from langchain_core.runnables import Runnable, RunnableConfig from loguru import logger from .bridge import LangchainRunContext, extract_prompt_text @@ -140,16 +140,16 @@ def __init__( self._logger = _bind_logger(langchain_context) self._client: Any | None = None - def invoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> Any: + def invoke(self, input: Any, config: RunnableConfig | None = None, **kwargs: Any) -> Any: # noqa: A002 try: asyncio.get_running_loop() except RuntimeError: - return asyncio.run(self.ainvoke(value, config=config, **kwargs)) + return asyncio.run(self.ainvoke(input, config=config, **kwargs)) raise RuntimeError("AgentProtocolRunnable.invoke cannot be used from a running event loop; use ainvoke instead") - async def ainvoke(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> Any: + async def ainvoke(self, input: Any, config: RunnableConfig | None = None, **kwargs: Any) -> Any: # noqa: A002 thread_id = await self._resolve_thread_id() - run_input = self._build_run_input(value) + run_input = self._build_run_input(input) metadata = self._build_metadata(config) self._logger.debug( "Invoking remote agent-protocol agent={} stateful={} thread_id={}", @@ -165,9 +165,14 @@ async def ainvoke(self, value: Any, config: dict[str, Any] | None = None, **kwar if_not_exists="create" if thread_id is not None else None, ) - async def astream(self, value: Any, config: dict[str, Any] | None = None, **kwargs: Any) -> AsyncIterator[str]: + async def astream( + self, + input: Any, # noqa: A002 + config: RunnableConfig | None = None, + **kwargs: Any | None, + ) -> AsyncIterator[str]: thread_id = await self._resolve_thread_id() - run_input = self._build_run_input(value) + run_input = self._build_run_input(input) metadata = self._build_metadata(config) self._logger.debug( "Streaming remote agent-protocol agent={} stateful={} thread_id={}", diff --git a/contrib/bubseek-langchain/src/bubseek_langchain/config.py b/contrib/bubseek-langchain/src/bubseek_langchain/config.py index 79df108..dc03bb2 100644 --- a/contrib/bubseek-langchain/src/bubseek_langchain/config.py +++ b/contrib/bubseek-langchain/src/bubseek_langchain/config.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal, cast from pydantic import AliasChoices, Field, ValidationError from pydantic_settings import BaseSettings, SettingsConfigDict @@ -47,7 +47,8 @@ def load_settings() -> LangchainPluginSettings: def load_agent_protocol_settings() -> AgentProtocolSettings: try: - return AgentProtocolSettings() + settings_cls = cast(Any, AgentProtocolSettings) + return cast(AgentProtocolSettings, settings_cls()) except ValidationError as exc: raise LangchainConfigError(str(exc)) from exc diff --git a/contrib/bubseek-langchain/tests/test_agent_protocol.py b/contrib/bubseek-langchain/tests/test_agent_protocol.py index 7af8f4d..72910c4 100644 --- a/contrib/bubseek-langchain/tests/test_agent_protocol.py +++ b/contrib/bubseek-langchain/tests/test_agent_protocol.py @@ -265,9 +265,6 @@ def test_remote_example_factory_uses_prompt_and_request_context( def test_remote_output_parser_extracts_visible_text_blocks() -> None: from examples.langchain.remote_agent_protocol import _parse_remote_agent_output - payload = ( - '[{"signature":"","thinking":"internal","type":"thinking"},' - '{"text":"Visible answer","type":"text"}]' - ) + payload = '[{"signature":"","thinking":"internal","type":"thinking"},{"text":"Visible answer","type":"text"}]' assert _parse_remote_agent_output(payload) == "Visible answer"