From 1728be989f2fe25ffa63ceea366da7afe0464f9c Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:17:59 +0000 Subject: [PATCH] refactor: integrate callAccumulator into mockMCPServer struct This addresses the tech debt identified in issue #73 by: 1. Creating mockMCPServer struct that combines: - MCP server proxies (Proxies map) - Call tracking (previously separate callAccumulator) 2. Replacing setupMCPServerProxiesForTest with newMockMCPServer: - Returns single struct instead of multiple values - Integrates all MCP server setup into one place 3. Updating setupInjectedToolTest: - Returns 3 values instead of 4 (recorder, mcpMock, resp) - Callers use mcpMock.GetToolCalls() instead of separate accumulator 4. Adding createMockMCPSrvHandler for cases needing custom setup (e.g., custom server name in trace tests) The changes reduce cognitive overhead by keeping related data together and simplifying function signatures. Resolves #73 --- bridge_integration_test.go | 125 ++++++++++++++++++++---------------- metrics_integration_test.go | 4 +- trace_integration_test.go | 10 +-- 3 files changed, 76 insertions(+), 63 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b4ea460..b84c8dd 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -724,28 +724,7 @@ func TestFallthrough(t *testing.T) { } } -// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) { - t.Helper() - - // Setup Coder MCP integration - srv, acc := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) - require.NoError(t, err) - // Initialize MCP client, fetch tools, and inject into bridge - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - require.NoError(t, proxy.Init(ctx)) - tools := proxy.ListTools() - require.NotEmpty(t, tools) - - return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc -} type ( configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) @@ -766,7 +745,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -776,7 +755,7 @@ func TestAnthropicInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mcpMock.GetToolCalls(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -853,7 +832,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -863,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mcpMock.GetToolCalls(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -942,8 +921,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. -// Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *mockMCPServer, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -988,11 +966,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} - // Setup MCP mcpProxiers. - mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) + // Setup MCP server with integrated call tracking. + mcpMock := newMockMCPServer(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1019,7 +997,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, acc, mcpProxiers, resp + return recorderClient, mcpMock, resp } func TestErrorHandling(t *testing.T) { @@ -1277,10 +1255,10 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) + mcpMock := newMockMCPServer(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) @@ -1689,58 +1667,93 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { const mockToolName = "coder_list_workspaces" -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { +// mockMCPServer combines the MCP server proxy with call tracking. +// This addresses the tech debt of having callAccumulator as a separate return value. +type mockMCPServer struct { + Proxies map[string]mcp.ServerProxier + calls map[string][]any callsMu sync.Mutex } -func newCallAccumulator() *callAccumulator { - return &callAccumulator{ - calls: make(map[string][]any), - } +func (m *mockMCPServer) addCall(tool string, args any) { + m.callsMu.Lock() + defer m.callsMu.Unlock() + m.calls[tool] = append(m.calls[tool], args) } -func (a *callAccumulator) addCall(tool string, args any) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - a.calls[tool] = append(a.calls[tool], args) -} - -func (a *callAccumulator) getCallsByTool(name string) []any { - a.callsMu.Lock() - defer a.callsMu.Unlock() +// GetToolCalls returns all recorded invocations for a specific tool. +func (m *mockMCPServer) GetToolCalls(name string) []any { + m.callsMu.Lock() + defer m.callsMu.Unlock() // Protect against concurrent access of the slice. - result := make([]any, len(a.calls[name])) - copy(result, a.calls[name]) + result := make([]any, len(m.calls[name])) + copy(result, m.calls[name]) return result } -func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { +func newMockMCPServer(t *testing.T, tracer trace.Tracer) *mockMCPServer { t.Helper() + mock := &mockMCPServer{ + calls: make(map[string][]any), + } + s := server.NewMCPServer( "Mock coder MCP server", "1.0.0", server.WithToolCapabilities(true), ) - // Accumulate tool calls & their arguments. - acc := newCallAccumulator() + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { + tool := mcplib.NewTool(name, + mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), + ) + s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + mock.addCall(request.Params.Name, request.Params.Arguments) + return mcplib.NewToolResultText("mock"), nil + }) + } + + mcpSrv := httptest.NewServer(server.NewStreamableHTTPServer(s)) + t.Cleanup(mcpSrv.Close) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + require.NoError(t, proxy.Init(ctx)) + tools := proxy.ListTools() + require.NotEmpty(t, tools) + + mock.Proxies = map[string]mcp.ServerProxier{proxy.Name(): proxy} + return mock +} + +// createMockMCPSrvHandler creates just the HTTP handler for the mock MCP server. +// Use this when you need custom server configuration (e.g., custom server name). +func createMockMCPSrvHandler(t *testing.T) http.Handler { + t.Helper() + + s := server.NewMCPServer( + "Mock coder MCP server", + "1.0.0", + server.WithToolCapabilities(true), + ) for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { - acc.addCall(request.Params.Name, request.Params.Arguments) return mcplib.NewToolResultText("mock"), nil }) } - return server.NewStreamableHTTPServer(s), acc + return server.NewStreamableHTTPServer(s) } func openaiCfg(url, key string) aibridge.OpenAIConfig { diff --git a/metrics_integration_test.go b/metrics_integration_test.go index f326dec..387a6f6 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -237,8 +237,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMock := newMockMCPServer(t, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) diff --git a/trace_integration_test.go b/trace_integration_test.go index ee6574d..e7eadc5 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) defer resp.Body.Close() @@ -358,7 +358,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { model = "beddel" } - for _, proxy := range proxies { + for _, proxy := range mcpMock.Proxies { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] @@ -607,14 +607,14 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) defer resp.Body.Close() require.Len(t, recorderClient.interceptions, 1) intcID := recorderClient.interceptions[0].ID - for _, proxy := range proxies { + for _, proxy := range mcpMock.Proxies { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] @@ -687,7 +687,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - srv, _ := createMockMCPSrv(t) + srv := createMockMCPSrvHandler(t) mcpSrv := httptest.NewServer(srv) t.Cleanup(mcpSrv.Close)