diff --git a/eval_protocol/mcp/mcp_multi_client.py b/eval_protocol/mcp/mcp_multi_client.py
index 35c20d61..38f2dbee 100644
--- a/eval_protocol/mcp/mcp_multi_client.py
+++ b/eval_protocol/mcp/mcp_multi_client.py
@@ -2,16 +2,21 @@
 import os
 from contextlib import AsyncExitStack
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from dotenv import load_dotenv
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
+from mcp.client.streamable_http import streamablehttp_client
 from mcp.types import CallToolResult
 from openai.types import FunctionDefinition
 from openai.types.chat import ChatCompletionToolParam
 
-from eval_protocol.types.types import MCPMultiClientConfiguration
+from eval_protocol.models import (
+    MCPConfigurationServerStdio,
+    MCPConfigurationServerUrl,
+    MCPMultiClientConfiguration,
+)
 
 load_dotenv()  # load environment variables from .env
 
@@ -38,10 +43,10 @@ def _load_config(self, config_path: Optional[str] = None) -> MCPMultiClientConfi
         """Load MCP server configuration from file or use default"""
         if config_path and os.path.exists(config_path):
             with open(config_path, "r") as f:
-                return json.load(f)
+                return MCPMultiClientConfiguration(**json.load(f))
 
         # Default configuration - can be overridden by config file
-        return {"mcpServers": {}}
+        return MCPMultiClientConfiguration(mcpServers={})
 
     def _validate_environment_variables(self, server_name: str, required_env: List[str]) -> None:
         """Validate that required environment variables are set in os.environ"""
@@ -59,35 +64,54 @@ def _validate_environment_variables(self, server_name: str, required_env: List[s
 
     async def connect_to_servers(self):
         """Connect to all configured MCP servers"""
-        if not self.config.get("mcpServers"):
+        if not self.config.mcpServers:
             print("No MCP servers configured. Please provide a configuration file.")
             return
 
-        for server_name, server_config in self.config["mcpServers"].items():
+        for server_name, server_config in self.config.mcpServers.items():
             try:
                 await self._connect_to_server(server_name, server_config)
             except Exception as e:
                 print(f"Failed to connect to server '{server_name}': {e}")
 
-    async def _connect_to_server(self, server_name: str, server_config: Dict[str, Any]):
+    async def _connect_to_server(
+        self, server_name: str, server_config: Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]
+    ):
         """Connect to a specific MCP server using its configuration"""
-        command = server_config.get("command")
-        args = server_config.get("args", [])
-        env_config = server_config.get("env", [])
-
-        if not command:
-            raise ValueError(f"Server '{server_name}' must have a 'command' specified")
-
-        # Validate that required environment variables are set
-        if env_config:
-            self._validate_environment_variables(server_name, env_config)
-
-        # Use the current system environment (os.environ) - don't override with config
-        server_params = StdioServerParameters(command=command, args=args, env=os.environ)
-
-        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
-        stdio, write = stdio_transport
-        session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
+        session: ClientSession
+
+        if isinstance(server_config, MCPConfigurationServerStdio):
+            # Handle stdio-based MCP server
+            command = server_config.command
+            args = server_config.args
+            env_config = server_config.env
+
+            if not command:
+                raise ValueError(f"Server '{server_name}' must have a 'command' specified")
+
+            # Validate that required environment variables are set
+            if env_config:
+                self._validate_environment_variables(server_name, env_config)
+
+            # Use the current system environment (os.environ) - don't override with config
+            server_params = StdioServerParameters(command=command, args=args, env=os.environ)
+
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            stdio, write = stdio_transport
+            session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
+
+        elif isinstance(server_config, MCPConfigurationServerUrl):
+            # Handle HTTP-based MCP server
+            url = server_config.url
+            if not url:
+                raise ValueError(f"Server '{server_name}' must have a 'url' specified")
+
+            # Connect using streamable HTTP client - manage resources manually
+            http_transport = await self.exit_stack.enter_async_context(streamablehttp_client(url))
+            read_stream, write_stream, get_session_id = http_transport
+            session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
+        else:
+            raise ValueError(f"Unsupported server configuration type: {type(server_config)}")
 
         await session.initialize()
         self.sessions[server_name] = session
diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 91a3144d..ba227c22 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from openai.types import CompletionUsage
 from openai.types.chat.chat_completion_message import (
@@ -8,11 +8,18 @@
 from pydantic import BaseModel, ConfigDict, Field
 
 
+class ChatCompletionContentPartTextParam(BaseModel):
+    text: str = Field(..., description="The text content.")
+    type: Literal["text"] = Field("text", description="The type of the content part.")
+
+
 class Message(BaseModel):
     """Chat message model with trajectory evaluation support."""
 
-    role: str
-    content: Optional[str] = ""  # Content can be None for tool calls in OpenAI API
+    role: str  # assistant, user, system, tool
+    content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
+        default="", description="The content of the message."
+    )
     name: Optional[str] = None
     tool_call_id: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
@@ -426,3 +433,23 @@ class Config:
 # from pydantic import ConfigDict
 # model_config = ConfigDict(extra='allow')
 # For Pydantic v1, `Config.extra = "allow"` is correct.
+
+
+class MCPConfigurationServerStdio(BaseModel):
+    """Represents a MCP configuration server."""
+
+    command: str  # command to run the MCP server
+    args: List[str] = Field(default_factory=list)  # to pass to the command
+    env: List[str] = Field(default_factory=list)  # List of environment variables to verify exist in the environment
+
+
+class MCPConfigurationServerUrl(BaseModel):
+    """Represents a Remote MCP configuration server."""
+
+    url: str  # url to the MCP server
+
+
+class MCPMultiClientConfiguration(BaseModel):
+    """Represents a MCP configuration."""
+
+    mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]
diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py
index a753a8dc..00a61692 100644
--- a/eval_protocol/pytest/default_agent_rollout_processor.py
+++ b/eval_protocol/pytest/default_agent_rollout_processor.py
@@ -3,9 +3,9 @@
 import os
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
-from mcp.types import CallToolResult
+from mcp.types import CallToolResult, TextContent
 from openai import NOT_GIVEN, NotGiven
 from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
 
 from eval_protocol.mcp.execution.policy import LiteLLMPolicy
@@ -57,16 +57,16 @@ async def call_agent(self) -> str:
             tool_tasks.append(task)
 
         # Execute all tool calls in parallel
-        tool_results = await asyncio.gather(*tool_tasks)
+        tool_results: List[Tuple[str, str]] = await asyncio.gather(*tool_tasks)
 
         # Add all tool results to messages (they will be in the same order as tool_calls)
         for tool_call, (tool_call_id, content) in zip(message["tool_calls"], tool_results):
             self.messages.append(
-                {
-                    "role": "tool",
-                    "content": content,
-                    "tool_call_id": tool_call_id,
-                }
+                Message(
+                    role="tool",
+                    content=content,
+                    tool_call_id=tool_call_id,
+                )
             )
         return await self.call_agent()
     return message["content"]
@@ -88,15 +88,12 @@
         content = self._get_content_from_tool_result(tool_result)
         return tool_call_id, content
 
-    def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
+    def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
         if tool_result.structuredContent:
             return json.dumps(tool_result.structuredContent)
-        if len(tool_result.content) > 1:
-            raise NotImplementedError("Multiple content is not supported yet")
-        first_content = tool_result.content[0]
-        if first_content.type != "text":
+        if not all(isinstance(content, TextContent) for content in tool_result.content):
             raise NotImplementedError("Non-text content is not supported yet")
-        return first_content.text
+        return "\n".join(content.text for content in tool_result.content)
 
 
 async def default_agent_rollout_processor(
@@ -108,4 +105,6 @@
         await agent.setup()
         await agent.call_agent()
         dataset.append(EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth))
+        if agent.mcp_client:
+            await agent.mcp_client.cleanup()
     return dataset
diff --git a/eval_protocol/types/types.py b/eval_protocol/types/types.py
index 9e37f391..953f6aa6 100644
--- a/eval_protocol/types/types.py
+++ b/eval_protocol/types/types.py
@@ -71,19 +71,3 @@ class Trajectory:
     termination_reason: str
     conversation_history: List[Dict[str, Any]]
     usage: Dict[str, int] = field(default_factory=dict)
-
-
-@dataclass
-class MCPConfigurationServer:
-    """Represents a MCP configuration server."""
-
-    command: str  # command to run the MCP server
-    args: List[str]  # to pass to the command
-    env: List[str]  # List of environment variables to verify exist in the environment
-
-
-@dataclass
-class MCPMultiClientConfiguration:
-    """Represents a MCP configuration."""
-
-    mcp_servers: Dict[str, MCPConfigurationServer]
diff --git a/tests/pytest/mcp_configurations/docs_mcp_config.json b/tests/pytest/mcp_configurations/docs_mcp_config.json
new file mode 100644
index 00000000..5ba3712e
--- /dev/null
+++ b/tests/pytest/mcp_configurations/docs_mcp_config.json
@@ -0,0 +1,7 @@
+{
+  "mcpServers": {
+    "docs.fireworks.ai": {
+      "url": "https://docs.fireworks.ai/mcp"
+    }
+  }
+}
diff --git a/tests/pytest/test_pytest_mcp_url.py b/tests/pytest/test_pytest_mcp_url.py
new file mode 100644
index 00000000..78258199
--- /dev/null
+++ b/tests/pytest/test_pytest_mcp_url.py
@@ -0,0 +1,42 @@
+from eval_protocol.models import EvaluateResult, Message, EvaluationRow
+from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
+
+
+@evaluation_test(
+    input_messages=[
+        [
+            Message(
+                role="system",
+                content=(
+                    "You are a helpful assistant that can answer questions about Fireworks.\n"
+                    "ALWAYS provide code or commands to execute to answer the question."
+                ),
+            ),
+            Message(
+                role="user",
+                content=("Can you teach me about how to manage deployments on Fireworks"),
+            ),
+        ]
+    ],
+    rollout_processor=default_agent_rollout_processor,
+    model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
+    mode="pointwise",
+    mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config.json",
+)
+def test_pytest_mcp_url(row: EvaluationRow) -> EvaluationRow:
+    """Check that the agent made at least one MCP tool call during rollout."""
+    # filter for all tool calls
+    tool_calls = [msg for msg in row.messages if msg.role == "tool"]
+
+    if len(tool_calls) == 0:
+        row.evaluation_result = EvaluateResult(
+            score=0,
+            feedback="No tool calls made",
+        )
+        return row
+
+    row.evaluation_result = EvaluateResult(
+        score=1,
+        feedback="At least one tool call was made",
+    )
+    return row