Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 147 additions & 7 deletions livekit-agents/livekit/agents/llm/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import asyncio
import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from collections.abc import Awaitable, Callable, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path
from typing import Any, Literal
from urllib.parse import urlparse

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from typing_extensions import Self

Expand Down Expand Up @@ -45,6 +46,10 @@
)

# Tool type returned by MCPServer.list_tools().
MCPTool = RawFunctionTool
# Zero-argument callback invoked when the server sends a tool-list-changed notification.
_ToolListChangedCallback = Callable[[], None]
# Callback that receives the last exception once a tool reload has exhausted its retries.
_ReloadFailedCallback = Callable[[BaseException], None]

# Seconds to sleep before each retry of a failed tool reload. The first attempt runs
# immediately, so a reload is tried len(_RELOAD_RETRY_DELAYS) + 1 times in total.
_RELOAD_RETRY_DELAYS: tuple[float, ...] = (1.0, 2.0, 4.0)


@dataclass
Expand Down Expand Up @@ -86,7 +91,10 @@ def __init__(
)

self._cache_dirty = True
self._cache_generation = 0
self._lk_tools: list[MCPTool] | None = None
self._tool_list_changed_callbacks: set[_ToolListChangedCallback] = set()
self._reload_failed_callbacks: set[_ReloadFailedCallback] = set()

self._client_task: asyncio.Task[None] | None = None
self._closing_ev = asyncio.Event()
Expand All @@ -98,6 +106,7 @@ def initialized(self) -> bool:

def invalidate_cache(self) -> None:
    """Mark the cached tool list stale and advance the cache generation.

    The generation counter lets an in-flight ``list_tools()`` call notice that an
    invalidation happened while it was awaiting the server.
    """
    self._cache_generation += 1
    self._cache_dirty = True

async def initialize(self) -> None:
if self._client_task and not self._client_task.done():
Expand All @@ -122,6 +131,7 @@ async def _run_client(self, ready_fut: asyncio.Future[None]) -> None:
read_timeout_seconds=timedelta(seconds=self._read_timeout)
if self._read_timeout
else None,
message_handler=self._handle_message,
) as client:
await client.initialize()
self._client = client
Expand All @@ -140,21 +150,65 @@ async def _run_client(self, ready_fut: asyncio.Future[None]) -> None:
self._lk_tools = None
self._closing_ev.clear()

async def _handle_message(self, message: Any) -> None:
    """Handle a message delivered by the MCP client session.

    Only tool-list-changed server notifications are acted upon; exceptions and
    every other message kind are ignored.
    """
    # Always yield to the loop, matching the MCP SDK default handler. Without
    # this, a burst of notifications starves the receive loop of cancellation
    # points.
    await anyio.lowlevel.checkpoint()

    if isinstance(message, Exception):
        return

    if not isinstance(message, mcp.types.ServerNotification):
        return

    if isinstance(message.root, mcp.types.ToolListChangedNotification):
        self._handle_tool_list_changed()

def _handle_tool_list_changed(self) -> None:
    """Invalidate the tool cache and notify every registered change listener.

    A listener that raises is logged and does not prevent the remaining
    listeners from running.
    """
    self.invalidate_cache()
    # Refetch from toolset tasks instead of the MCP receive loop, which must stay free
    # to deliver the list_tools() response.
    # Iterate over a snapshot so listeners may (un)register while being notified.
    listeners = tuple(self._tool_list_changed_callbacks)
    for notify in listeners:
        try:
            notify()
        except Exception:
            logger.exception("error in MCP tool list changed callback")

def _add_tool_list_changed_callback(self, callback: _ToolListChangedCallback) -> None:
    """Register ``callback`` to run when the server reports a tool list change.

    Callbacks are kept in a set, so registering the same callable again has no
    further effect.
    """
    listeners = self._tool_list_changed_callbacks
    listeners.add(callback)

def _remove_tool_list_changed_callback(self, callback: _ToolListChangedCallback) -> None:
    """Unregister ``callback``; a callback that was never registered is ignored."""
    listeners = self._tool_list_changed_callbacks
    listeners.discard(callback)

def _add_reload_failed_callback(self, callback: _ReloadFailedCallback) -> None:
    """Register ``callback`` to receive the error when a tool reload fails.

    Callbacks are kept in a set, so registering the same callable again has no
    further effect.
    """
    listeners = self._reload_failed_callbacks
    listeners.add(callback)

def _remove_reload_failed_callback(self, callback: _ReloadFailedCallback) -> None:
    """Unregister ``callback``; a callback that was never registered is ignored."""
    listeners = self._reload_failed_callbacks
    listeners.discard(callback)

def _notify_reload_failed(self, error: BaseException) -> None:
    """Pass ``error`` to every registered reload-failed callback.

    A callback that raises is logged and does not prevent the remaining
    callbacks from running.
    """
    # Iterate over a snapshot so callbacks may (un)register while being notified.
    listeners = tuple(self._reload_failed_callbacks)
    for notify in listeners:
        try:
            notify(error)
        except Exception:
            logger.exception("error in MCP reload failed callback")

async def list_tools(self) -> list[MCPTool]:
    """Return the server's tools, fetching them when the cache is stale.

    Returns:
        The cached tool list while it is still valid, otherwise a freshly
        fetched list converted with ``_make_function_tool``.

    Raises:
        RuntimeError: If no client session is available (``_client`` is ``None``).
    """
    if self._client is None:
        raise RuntimeError("MCPServer isn't initialized")

    if not self._cache_dirty and self._lk_tools is not None:
        return self._lk_tools

    # Remember which generation this fetch belongs to: invalidate_cache() may run
    # while we await the server, in which case the result is already stale.
    cache_generation = self._cache_generation
    tools = await self._client.list_tools()
    lk_tools = [
        self._make_function_tool(tool.name, tool.description, tool.inputSchema, tool.meta)
        for tool in tools.tools
    ]

    self._lk_tools = lk_tools
    # Mark the cache clean only when no invalidation raced with the fetch;
    # otherwise leave it dirty so the next call refetches.
    self._cache_dirty = cache_generation != self._cache_generation
    return lk_tools

def _make_function_tool(
Expand Down Expand Up @@ -444,6 +498,10 @@ def __init__(self, *, id: str, mcp_server: MCPServer) -> None:
self._mcp_server = mcp_server
self._initialized = False
self._lock = asyncio.Lock()
self._reload_requested = False
self._reload_task: asyncio.Task[None] | None = None
self._listening_for_tool_changes = False
self._tool_filter: Callable[[MCPTool], bool] | None = None

async def setup(self, *, reload: bool = False) -> Self:
"""Initialize the MCP server connection and fetch available tools.
Expand All @@ -461,27 +519,109 @@ async def setup(self, *, reload: bool = False) -> Self:
if not reload and self._initialized:
return self

# Register before the first list_tools() call so a concurrent list_changed
# notification queues a reload instead of leaving the initial result stale.
if not self._listening_for_tool_changes:
self._mcp_server._add_tool_list_changed_callback(self._request_tools_reload)
self._listening_for_tool_changes = True

if not self._mcp_server.initialized:
await self._mcp_server.initialize()
elif reload:
self._mcp_server.invalidate_cache()

tools = await self._mcp_server.list_tools()
self._tools = tools
self._tools = self._apply_filter(tools)
self._initialized = True
return self

def filter_tools(self, filter_fn: Callable[[MCPTool], bool]) -> Self:
    """Restrict the toolset to the tools accepted by ``filter_fn``.

    The predicate is remembered and AND-combined with any predicate from an
    earlier call, so ``_apply_filter`` keeps enforcing it on tool lists passed
    to it later, not only on the tools currently held.

    Args:
        filter_fn: Returns ``True`` for tools that should be kept.

    Returns:
        This toolset, to allow chaining.
    """
    previous_filter = self._tool_filter
    if previous_filter is None:
        self._tool_filter = filter_fn
    else:
        # Bind the earlier predicate to a local so the composed filter does not
        # end up calling itself through self._tool_filter.
        self._tool_filter = lambda tool: previous_filter(tool) and filter_fn(tool)

    # One pass through the composed filter is sufficient; pre-filtering with
    # filter_fn first would evaluate the predicate twice for every tool.
    self._tools = self._apply_filter(
        [tool for tool in self._tools if isinstance(tool, MCPTool)]
    )
    return self

def _apply_filter(self, tools: Sequence[MCPTool]) -> list[MCPTool]:
    """Return a new list with the tools from ``tools`` that pass the stored filter.

    When no filter has been set, every tool is kept.
    """
    keep = self._tool_filter
    return list(tools) if keep is None else [tool for tool in tools if keep(tool)]

def _request_tools_reload(self) -> None:
    """Flag that the tools need reloading and make sure a reload task is running.

    Registered as the server's tool-list-changed callback. If a reload task is
    already in flight only the flag is set, and that task's loop picks it up.
    """
    # Guard against snapshot-iteration races: MCPServer._handle_tool_list_changed
    # snapshots its callback set, then fires each. If aclose() unsubscribed us
    # between the snapshot and this call, the listening flag is already False
    # and we must not spawn a reload task against a server that may be tearing
    # down.
    if not self._listening_for_tool_changes:
        return

    self._reload_requested = True

    running = self._reload_task
    if running is None or running.done():
        self._reload_task = asyncio.create_task(
            self._reload_tools(), name=f"{type(self).__name__}._reload_tools"
        )

async def _reload_tools(self) -> None:
    """Reload the tools until no further reload request is pending.

    Cancellation is not intercepted: ``asyncio.CancelledError`` propagates to
    whoever awaits the task.
    """
    while self._reload_requested:
        # Clear the flag before reloading so a request that arrives while the
        # reload is in progress triggers one more pass.
        self._reload_requested = False
        await self._reload_tools_with_retry()

async def _reload_tools_with_retry(self) -> None:
    """Reload the tools, retrying after each delay in ``_RELOAD_RETRY_DELAYS``.

    The first attempt runs immediately. On success the tools-changed callbacks
    are notified and the method returns. When every attempt fails, the last
    error is logged and handed to the server's reload-failed callbacks.
    Cancellation is never swallowed.
    """
    schedule = (0.0,) + _RELOAD_RETRY_DELAYS
    total_attempts = len(_RELOAD_RETRY_DELAYS) + 1
    failure: BaseException | None = None

    for attempt_no, wait_s in enumerate(schedule, start=1):
        if wait_s:
            await asyncio.sleep(wait_s)
        try:
            await self.setup(reload=True)
            self._notify_tools_changed()
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            failure = exc
            logger.warning(
                "failed to reload MCP tools (attempt %d/%d): %s",
                attempt_no,
                total_attempts,
                exc,
            )
        else:
            return

    # The schedule always holds at least one attempt, so reaching this point
    # means the last attempt recorded an error.
    assert failure is not None
    logger.exception(
        "giving up reloading MCP tools after tools/list_changed notification",
        exc_info=failure,
    )
    self._mcp_server._notify_reload_failed(failure)

async def aclose(self) -> None:
    """Stop listening for tool changes, cancel any reload, and close the server.

    The toolset's state (initialized flag, tools, pending reload request) is
    reset even when closing raises.
    """
    try:
        if self._listening_for_tool_changes:
            self._mcp_server._remove_tool_list_changed_callback(self._request_tools_reload)
            self._listening_for_tool_changes = False

        pending = self._reload_task
        # Never cancel/await ourselves if aclose() is running inside the reload task.
        if pending is not None and pending is not asyncio.current_task():
            pending.cancel()
            try:
                await pending
            except asyncio.CancelledError:
                pass

        await super().aclose()
        await self._mcp_server.aclose()
    finally:
        self._initialized = False
        self._tools = []
        self._reload_requested = False
14 changes: 14 additions & 0 deletions livekit-agents/livekit/agents/llm/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self, *, id: str, tools: Sequence[Tool | Toolset] | None = None) ->
self._id = id
self._tools: Sequence[Tool | Toolset] = list(tools) if tools is not None else []
self._tools.extend(find_function_tools(self))
self._tools_changed_callbacks: set[Callable[[Toolset], None]] = set()

@property
def id(self) -> str:
Expand Down Expand Up @@ -94,6 +95,19 @@ async def aclose(self) -> None:
if toolsets:
await asyncio.gather(*(toolset.aclose() for toolset in toolsets))

def _add_tools_changed_callback(self, callback: Callable[[Toolset], None]) -> None:
    """Register ``callback`` to be called with this toolset when its tools change.

    Callbacks are kept in a set, so registering the same callable again has no
    further effect.
    """
    subscribers = self._tools_changed_callbacks
    subscribers.add(callback)

def _remove_tools_changed_callback(self, callback: Callable[[Toolset], None]) -> None:
    """Unregister ``callback``; a callback that was never registered is ignored."""
    subscribers = self._tools_changed_callbacks
    subscribers.discard(callback)

def _notify_tools_changed(self) -> None:
    """Call every registered tools-changed callback with this toolset.

    A callback that raises is logged and does not prevent the remaining
    callbacks from running.
    """
    # Iterate over a snapshot so callbacks may (un)register while being notified.
    subscribers = tuple(self._tools_changed_callbacks)
    for notify in subscribers:
        try:
            notify(self)
        except Exception:
            logger.exception("error in tools_changed callback")


# Used by ToolChoice
class Function(TypedDict, total=False):
Expand Down
Loading