diff --git a/temporalio/contrib/google_adk_agents/_mcp.py b/temporalio/contrib/google_adk_agents/_mcp.py index 6c6123806..213f2822d 100644 --- a/temporalio/contrib/google_adk_agents/_mcp.py +++ b/temporalio/contrib/google_adk_agents/_mcp.py @@ -90,7 +90,9 @@ class TemporalMcpToolSetProvider: within Temporal workflows. """ - def __init__(self, name: str, toolset_factory: Callable[[Any | None], McpToolset]): + def __init__( + self, name: str, toolset_factory: Callable[[Any | None], McpToolset] + ) -> None: """Initializes the toolset provider. Args: @@ -215,6 +217,7 @@ def __init__( name: str, config: ActivityConfig | None = None, factory_argument: Any | None = None, + local_toolset: Callable[[Any | None], McpToolset] | None = None, ): """Initializes the Temporal MCP toolset. @@ -222,6 +225,7 @@ def __init__( name: Name of the toolset (used for activity naming). config: Optional activity configuration. factory_argument: Optional argument passed to toolset factory. + local_toolset: Optional factory for a temporal toolset for local execution when running outside a durable workflow. """ super().__init__() self._name = name @@ -229,6 +233,7 @@ def __init__( self._config = config or ActivityConfig( start_to_close_timeout=timedelta(minutes=1) ) + self._local_toolset = local_toolset async def get_tools( self, readonly_context: ReadonlyContext | None = None @@ -241,6 +246,14 @@ async def get_tools( Returns: List of available tools wrapped as Temporal activities. """ + # If executed outside a workflow, like when doing local adk runs, use the mcp server directly + if not workflow.in_workflow(): + if self._local_toolset is None: + raise ValueError( + "Attempted to execute an MCP tool declared with TemporalMcpToolSet outside of a Workflow. Either use McpToolSet or pass a copy of your MCP toolset provider into local_toolset." + ) + return await self._local_toolset(None).get_tools(readonly_context) + tool_results: list[_ToolResult] = await workflow.execute_activity( self._name + "-list-tools", _GetToolsArguments(self._factory_argument), diff --git a/temporalio/contrib/google_adk_agents/_model.py b/temporalio/contrib/google_adk_agents/_model.py index 80079433c..6d1e7ffa9 100644 --- a/temporalio/contrib/google_adk_agents/_model.py +++ b/temporalio/contrib/google_adk_agents/_model.py @@ -5,6 +5,7 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +import temporalio.workflow from temporalio import activity, workflow from temporalio.workflow import ActivityConfig @@ -67,6 +68,14 @@ async def generate_content_async( Yields: The responses from the model. """ + # If executed outside a workflow, like when doing local adk runs, use the model directly + if not temporalio.workflow.in_workflow(): + async for response in LLMRegistry.new_llm( + self._model_name + ).generate_content_async(llm_request, stream=stream): + yield response + return + responses = await workflow.execute_activity( invoke_model, args=[llm_request], diff --git a/temporalio/contrib/google_adk_agents/workflow.py b/temporalio/contrib/google_adk_agents/workflow.py index 42ff7246f..93815aaba 100644 --- a/temporalio/contrib/google_adk_agents/workflow.py +++ b/temporalio/contrib/google_adk_agents/workflow.py @@ -3,6 +3,7 @@ import inspect from typing import Any, Callable +import temporalio.workflow from temporalio import workflow @@ -29,6 +30,14 @@ async def wrapper(*args: Any, **kw: Any): # Decorator kwargs are defaults. options = kwargs.copy() + if not temporalio.workflow.in_workflow(): + # If executed outside a workflow, like when doing local adk runs, use the function directly + result = activity_def(*args, **kw) + if inspect.isawaitable(result): + return await result + else: + return result + return await workflow.execute_activity(activity_def, *activity_args, **options) # Copy metadata diff --git a/tests/contrib/google_adk_agents/test_google_adk_agents.py b/tests/contrib/google_adk_agents/test_google_adk_agents.py index 4d41b6a82..5d986236c 100644 --- a/tests/contrib/google_adk_agents/test_google_adk_agents.py +++ b/tests/contrib/google_adk_agents/test_google_adk_agents.py @@ -18,8 +18,9 @@ import os import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Iterator +from collections.abc import AsyncGenerator from datetime import timedelta +from typing import Any import pytest from google.adk import Agent, Runner @@ -64,6 +65,19 @@ async def get_weather(city: str) -> str: # type: ignore[reportUnusedParameter] return "Warm and sunny. 17 degrees." +def weather_agent(model_name: str) -> Agent: + # Wraps 'get_weather' activity as a Tool + weather_tool = temporalio.contrib.google_adk_agents.workflow.activity_tool( + get_weather, start_to_close_timeout=timedelta(seconds=60) + ) + + return Agent( + name="test_agent", + model=TemporalModel(model_name), + tools=[weather_tool], + ) + + @workflow.defn class WeatherAgent: @workflow.run @@ -73,17 +87,7 @@ async def run(self, prompt: str, model_name: str) -> Event | None: # 1. Define Agent using Temporal Helpers # Note: AgentPlugin in the Runner automatically handles Runtime setup # and Model Activity interception. We use standard ADK models now. - - # Wraps 'get_weather' activity as a Tool - weather_tool = temporalio.contrib.google_adk_agents.workflow.activity_tool( - get_weather, start_to_close_timeout=timedelta(seconds=60) - ) - - agent = Agent( - name="test_agent", - model=TemporalModel(model_name), - tools=[weather_tool], - ) + agent = weather_agent(model_name) # 2. Create runner runner = InMemoryRunner( @@ -357,6 +361,30 @@ async def test_multi_agent(client: Client, use_local_model: bool): assert result == "haiku" +def example_toolset(_: Any | None) -> McpToolset: + return McpToolset( + connection_params=StdioConnectionParams( + server_params=StdioServerParameters( + command="npx", + args=[ + "-y", + "@modelcontextprotocol/server-filesystem", + os.path.dirname(os.path.abspath(__file__)), + ], + ), + ), + ) + + +def mcp_agent(model_name: str) -> Agent: + return Agent( + name="test_agent", + # instruction="Always use your tools to answer questions.", + model=TemporalModel(model_name), + tools=[TemporalMcpToolSet("test_set", local_toolset=example_toolset)], + ) + + @workflow.defn class McpAgent: @workflow.run @@ -364,14 +392,7 @@ async def run(self, prompt: str, model_name: str) -> str: logger.info("Workflow started.") # 1. Define Agent using Temporal Helpers - # Note: AgentPlugin in the Runner automatically handles Runtime setup - # and Model Activity interception. We use standard ADK models now. - agent = Agent( - name="test_agent", - # instruction="Always use your tools to answer questions.", - model=TemporalModel(model_name), - tools=[TemporalMcpToolSet("test_set")], - ) + agent = mcp_agent(model_name) # 2. Create Session (uses runtime.new_uuid() -> workflow.uuid4()) session_service = InMemorySessionService() @@ -408,39 +429,36 @@ async def run(self, prompt: str, model_name: str) -> str: return last_event.content.parts[0].text -class McpModel(BaseLlm): - responses: list[LlmResponse] = [ - LlmResponse( - content=Content( - role="model", - parts=[ - Part( - function_call=FunctionCall( - args={"path": os.path.dirname(os.path.abspath(__file__))}, - name="list_directory", +class McpModel(TestModel): + def responses(self) -> list[LlmResponse]: + return [ + LlmResponse( + content=Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + args={ + "path": os.path.dirname(os.path.abspath(__file__)) + }, + name="list_directory", + ) ) - ) - ], - ) - ), - LlmResponse( - content=Content( - role="model", - parts=[Part(text="Some files.")], - ) - ), - ] - response_iter: Iterator[LlmResponse] = iter(responses) + ], + ) + ), + LlmResponse( + content=Content( + role="model", + parts=[Part(text="Some files.")], + ) + ), + ] @classmethod def supported_models(cls) -> list[str]: return ["mcp_model"] - async def generate_content_async( - self, llm_request: LlmRequest, stream: bool = False - ) -> AsyncGenerator[LlmResponse, None]: - yield next(self.response_iter) - @pytest.mark.parametrize("use_local_model", [True, False]) @pytest.mark.asyncio @@ -455,18 +473,7 @@ async def test_mcp_agent(client: Client, use_local_model: bool): toolset_providers=[ TemporalMcpToolSetProvider( "test_set", - lambda _: McpToolset( - connection_params=StdioConnectionParams( - server_params=StdioServerParameters( - command="npx", - args=[ - "-y", - "@modelcontextprotocol/server-filesystem", - os.path.dirname(os.path.abspath(__file__)), - ], - ), - ), - ), + example_toolset, ) ], ) @@ -567,3 +574,88 @@ async def test_single_agent_telemetry(client: Client): async def test_unsetting_timeout(): model = TemporalModel("", ActivityConfig(start_to_close_timeout=None)) assert model._activity_config.get("start_to_close_timeout", None) is None + + +@pytest.mark.asyncio +async def test_agent_outside_workflow(): + """Test that an agent using TemporalModel and activity_tool works outside a Temporal workflow.""" + LLMRegistry.register(WeatherModel) + + agent = weather_agent("weather_model") + + runner = InMemoryRunner( + agent=agent, + app_name="test_app_local", + ) + + session = await runner.session_service.create_session( + app_name="test_app_local", user_id="test" + ) + + last_event = None + async with Aclosing( + runner.run_async( + user_id="test", + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part(text="What is the weather in New York?")] + ), + ) + ) as agen: + async for event in agen: + last_event = event + + assert last_event is not None + assert last_event.content is not None + assert last_event.content.parts is not None + assert last_event.content.parts[0].text == "warm and sunny" + + +@pytest.mark.asyncio +@pytest.mark.skip # Doesn't work well in CI currently +async def test_mcp_agent_outside_workflow(): + """Test that an agent using TemporalMcpToolSet works outside a Temporal workflow.""" + LLMRegistry.register(McpModel) + + agent = mcp_agent("mcp_model") + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app_local", user_id="test" + ) + + runner = Runner( + agent=agent, + app_name="test_app_local", + session_service=session_service, + ) + + last_event = None + async with Aclosing( + runner.run_async( + user_id="test", + session_id=session.id, + new_message=types.Content( + role="user", + parts=[types.Part(text="What files are in the current directory?")], + ), + ) + ) as agen: + async for event in agen: + last_event = event + + assert last_event is not None + assert last_event.content is not None + assert last_event.content.parts is not None + assert last_event.content.parts[0].text == "Some files." + + +@pytest.mark.asyncio +async def test_mcp_toolset_outside_workflow_no_local_toolset(): + """Test that TemporalMcpToolSet raises ValueError outside a workflow with no local_toolset.""" + toolset = TemporalMcpToolSet("test_set_no_local") + with pytest.raises( + ValueError, + match="Attempted to execute an MCP tool", + ): + await toolset.get_tools()