Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions eval_protocol/mcp/mcp_multi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import os
from contextlib import AsyncExitStack
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult
from openai.types import FunctionDefinition
from openai.types.chat import ChatCompletionToolParam

from eval_protocol.types.types import MCPMultiClientConfiguration
from eval_protocol.models import (
MCPConfigurationServerStdio,
MCPConfigurationServerUrl,
MCPMultiClientConfiguration,
)

load_dotenv() # load environment variables from .env

def _load_config(self, config_path: Optional[str] = None) -> MCPMultiClientConfiguration:
    """Load the MCP multi-client configuration.

    Args:
        config_path: Optional path to a JSON file shaped like
            ``{"mcpServers": {...}}``. Ignored when missing on disk.

    Returns:
        A validated ``MCPMultiClientConfiguration``. Falls back to an empty
        server map when no usable config file is given.
    """
    if config_path and os.path.exists(config_path):
        # Validate through the pydantic model so a malformed config file
        # fails loudly here instead of at connect time.
        with open(config_path, "r") as f:
            return MCPMultiClientConfiguration(**json.load(f))

    # Default configuration - can be overridden by a config file.
    return MCPMultiClientConfiguration(mcpServers={})

def _validate_environment_variables(self, server_name: str, required_env: List[str]) -> None:
"""Validate that required environment variables are set in os.environ"""
Expand All @@ -59,35 +64,54 @@ def _validate_environment_variables(self, server_name: str, required_env: List[s

async def connect_to_servers(self):
    """Connect to all configured MCP servers.

    Failures are isolated per server: a server that cannot be reached is
    reported and skipped so the remaining servers still get sessions.
    """
    if not self.config.mcpServers:
        print("No MCP servers configured. Please provide a configuration file.")
        return

    for server_name, server_config in self.config.mcpServers.items():
        try:
            await self._connect_to_server(server_name, server_config)
        except Exception as e:
            # Best-effort: log and continue with the other servers.
            print(f"Failed to connect to server '{server_name}': {e}")

async def _connect_to_server(
    self, server_name: str, server_config: Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]
):
    """Connect to a specific MCP server using its configuration.

    Opens the transport appropriate to the config type (stdio subprocess or
    streamable HTTP), initializes a ``ClientSession`` on it, and stores the
    session in ``self.sessions[server_name]``. All transports/sessions are
    entered on ``self.exit_stack`` so cleanup happens in one place.

    Raises:
        ValueError: If the config is missing its command/url or is of an
            unsupported type.
    """
    session: ClientSession

    if isinstance(server_config, MCPConfigurationServerStdio):
        # Handle stdio-based MCP server (local subprocess).
        command = server_config.command
        args = server_config.args
        env_config = server_config.env

        if not command:
            raise ValueError(f"Server '{server_name}' must have a 'command' specified")

        # Validate that required environment variables are set before
        # spawning the subprocess.
        if env_config:
            self._validate_environment_variables(server_name, env_config)

        # Use the current system environment (os.environ) - don't override with config
        server_params = StdioServerParameters(command=command, args=args, env=os.environ)

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        stdio, write = stdio_transport
        session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))

    elif isinstance(server_config, MCPConfigurationServerUrl):
        # Handle HTTP-based MCP server (remote, streamable HTTP transport).
        url = server_config.url
        if not url:
            raise ValueError(f"Server '{server_name}' must have a 'url' specified")

        # streamablehttp_client yields (read, write, get_session_id); the
        # session-id getter is not needed for a plain client session.
        http_transport = await self.exit_stack.enter_async_context(streamablehttp_client(url))
        read_stream, write_stream, get_session_id = http_transport
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
    else:
        raise ValueError(f"Unsupported server configuration type: {type(server_config)}")

    await session.initialize()
    self.sessions[server_name] = session
Expand Down
33 changes: 30 additions & 3 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from openai.types import CompletionUsage
from openai.types.chat.chat_completion_message import (
Expand All @@ -8,11 +8,18 @@
from pydantic import BaseModel, ConfigDict, Field


class ChatCompletionContentPartTextParam(BaseModel):
    """One text part of a structured chat-message content list.

    Mirrors the OpenAI content-part shape ``{"type": "text", "text": ...}``
    so ``Message.content`` can hold either a plain string or a list of these.
    """

    text: str = Field(..., description="The text content.")
    type: Literal["text"] = Field("text", description="The type of the content part.")


class Message(BaseModel):
"""Chat message model with trajectory evaluation support."""

role: str
content: Optional[str] = "" # Content can be None for tool calls in OpenAI API
role: str # assistant, user, system, tool
content: Optional[Union[str, List[ChatCompletionContentPartTextParam]]] = Field(
default="", description="The content of the message."
)
name: Optional[str] = None
tool_call_id: Optional[str] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
Expand Down Expand Up @@ -426,3 +433,23 @@ class Config:
# from pydantic import ConfigDict
# model_config = ConfigDict(extra='allow')
# For Pydantic v1, `Config.extra = "allow"` is correct.


class MCPConfigurationServerStdio(BaseModel):
    """Configuration for an MCP server launched locally over the stdio transport."""

    command: str  # executable used to start the MCP server subprocess
    args: List[str] = Field(default_factory=list)  # command-line arguments passed to `command`
    env: List[str] = Field(default_factory=list)  # names of environment variables that must already exist in the environment (verified, not injected)


class MCPConfigurationServerUrl(BaseModel):
    """Configuration for a remote MCP server reached by URL (HTTP transport)."""

    url: str  # URL of the remote MCP server endpoint


class MCPMultiClientConfiguration(BaseModel):
    """Top-level MCP client configuration: maps server name to server config.

    The field name intentionally matches the ``"mcpServers"`` key of the JSON
    configuration file so the file can be unpacked directly into this model.
    """

    # Values are either stdio (local subprocess) or URL (remote) server configs.
    mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]
27 changes: 13 additions & 14 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import os
from typing import Any, List, Optional, Union

from mcp.types import CallToolResult
from mcp.types import CallToolResult, TextContent
from openai import NOT_GIVEN, NotGiven
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from eval_protocol.mcp.execution.policy import LiteLLMPolicy
Expand Down Expand Up @@ -57,16 +57,16 @@ async def call_agent(self) -> str:
tool_tasks.append(task)

# Execute all tool calls in parallel
tool_results = await asyncio.gather(*tool_tasks)
tool_results: List[List[TextContent]] = await asyncio.gather(*tool_tasks)

# Add all tool results to messages (they will be in the same order as tool_calls)
for tool_call, (tool_call_id, content) in zip(message["tool_calls"], tool_results):
self.messages.append(
{
"role": "tool",
"content": content,
"tool_call_id": tool_call_id,
}
Message(
role="tool",
content=content,
tool_call_id=tool_call_id,
)
)
return await self.call_agent()
return message["content"]
Expand All @@ -88,15 +88,12 @@ async def _execute_tool_call(self, tool_call_id: str, tool_name: str, tool_args_
content = self._get_content_from_tool_result(tool_result)
return tool_call_id, content

def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
    """Flatten an MCP tool result into a plain string for a tool message.

    Prefers ``structuredContent`` (serialized as JSON) when present;
    otherwise concatenates the text of all text content parts.

    Returns:
        The tool output as a single string ("" when there are no parts).

    Raises:
        NotImplementedError: If any content part is not text.
    """
    if tool_result.structuredContent:
        # Structured output wins: serialize it so the model sees valid JSON.
        return json.dumps(tool_result.structuredContent)
    if not all(isinstance(content, TextContent) for content in tool_result.content):
        raise NotImplementedError("Non-text content is not supported yet")
    # Join every part instead of silently dropping everything after the
    # first text block; also avoids IndexError on an empty content list.
    return "\n".join(part.text for part in tool_result.content)


async def default_agent_rollout_processor(
Expand All @@ -108,4 +105,6 @@ async def default_agent_rollout_processor(
await agent.setup()
await agent.call_agent()
dataset.append(EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth))
if agent.mcp_client:
await agent.mcp_client.cleanup()
return dataset
16 changes: 0 additions & 16 deletions eval_protocol/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,3 @@ class Trajectory:
termination_reason: str
conversation_history: List[Dict[str, Any]]
usage: Dict[str, int] = field(default_factory=dict)


@dataclass
class MCPConfigurationServer:
"""Represents a MCP configuration server."""

command: str # command to run the MCP server
args: List[str] # to pass to the command
env: List[str] # List of environment variables to verify exist in the environment


@dataclass
class MCPMultiClientConfiguration:
"""Represents a MCP configuration."""

mcp_servers: Dict[str, MCPConfigurationServer]
7 changes: 7 additions & 0 deletions tests/pytest/mcp_configurations/docs_mcp_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"mcpServers": {
"docs.fireworks.ai": {
"url": "https://docs.fireworks.ai/mcp"
}
}
}
42 changes: 42 additions & 0 deletions tests/pytest/test_pytest_mcp_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from eval_protocol.models import EvaluateResult, Message, EvaluationRow
from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test


@evaluation_test(
    input_messages=[
        [
            Message(
                role="system",
                content=(
                    "You are a helpful assistant that can answer questions about Fireworks.\n"
                    "ALWAYS provide code or commands to execute to answer the question."
                ),
            ),
            Message(
                role="user",
                content=("Can you teach me about how to manage deployments on Fireworks"),
            ),
        ]
    ],
    rollout_processor=default_agent_rollout_processor,
    model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
    mode="pointwise",
    mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config.json",
)
def test_pytest_mcp_url(row: EvaluationRow) -> EvaluationRow:
    """Evaluate an agent rollout against a URL-configured (remote) MCP server.

    Uses the Fireworks docs MCP server config and scores 1 when the rollout
    contains at least one tool-result message, 0 otherwise.
    """
    # Tool *result* messages carry role "tool" - one per executed tool call.
    tool_messages = [msg for msg in row.messages if msg.role == "tool"]

    if not tool_messages:
        row.evaluation_result = EvaluateResult(
            score=0,
            feedback="No tool calls made",
        )
        return row

    row.evaluation_result = EvaluateResult(
        score=1,
        feedback="At least one tool call was made",
    )
    return row
Loading