Skip to content

Commit 4a00d06

Browse files
authored
fix: openai blocking requests do not call injected tools correctly (#72)
* fix: openai blocking requests do not call injected tools correctly Signed-off-by: Danny Kopping <danny@coder.com> * chore: drive-by renaming Signed-off-by: Danny Kopping <danny@coder.com> --------- Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 7838abc commit 4a00d06

File tree

3 files changed

+62
-24
lines changed

3 files changed

+62
-24
lines changed

bridge_integration_test.go

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -727,11 +727,12 @@ func TestFallthrough(t *testing.T) {
727727
}
728728

729729
// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools
730-
func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
730+
func setupMCPServerProxiesForTest(t *testing.T) (map[string]mcp.ServerProxier, *callAccumulator) {
731731
t.Helper()
732732

733733
// Setup Coder MCP integration
734-
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
734+
srv, acc := createMockMCPSrv(t)
735+
mcpSrv := httptest.NewServer(srv)
735736
t.Cleanup(mcpSrv.Close)
736737

737738
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
@@ -745,7 +746,7 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
745746
tools := proxy.ListTools()
746747
require.NotEmpty(t, tools)
747748

748-
return map[string]mcp.ServerProxier{proxy.Name(): proxy}
749+
return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc
749750
}
750751

751752
type (
@@ -766,7 +767,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
766767
}
767768

768769
// Build the requirements & make the assertions which are common to all providers.
769-
recorderClient, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
770+
recorderClient, mcpCalls, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
770771

771772
// Ensure expected tool was invoked with expected input.
772773
require.Len(t, recorderClient.toolUsages, 1)
@@ -776,6 +777,11 @@ func TestAnthropicInjectedTools(t *testing.T) {
776777
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
777778
require.NoError(t, err)
778779
require.EqualValues(t, expected, actual)
780+
invocations := mcpCalls.getCallsByTool(mockToolName)
781+
require.Len(t, invocations, 1)
782+
actual, err = json.Marshal(invocations[0])
783+
require.NoError(t, err)
784+
require.EqualValues(t, expected, actual)
779785

780786
var (
781787
content *anthropic.ContentBlockUnion
@@ -847,7 +853,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
847853
}
848854

849855
// Build the requirements & make the assertions which are common to all providers.
850-
recorderClient, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
856+
recorderClient, mcpCalls, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
851857

852858
// Ensure expected tool was invoked with expected input.
853859
require.Len(t, recorderClient.toolUsages, 1)
@@ -857,6 +863,11 @@ func TestOpenAIInjectedTools(t *testing.T) {
857863
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
858864
require.NoError(t, err)
859865
require.EqualValues(t, expected, actual)
866+
invocations := mcpCalls.getCallsByTool(mockToolName)
867+
require.Len(t, invocations, 1)
868+
actual, err = json.Marshal(invocations[0])
869+
require.NoError(t, err)
870+
require.EqualValues(t, expected, actual)
860871

861872
var (
862873
content *openai.ChatCompletionChoice
@@ -932,7 +943,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
932943

933944
// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
934945
// Kinda fugly right now, we can refactor this later.
935-
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error), createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *http.Response) {
946+
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, *http.Response) {
936947
t.Helper()
937948

938949
arc := txtar.Parse(fixture)
@@ -977,11 +988,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
977988

978989
recorderClient := &mockRecorderClient{}
979990

980-
// Setup MCP tools.
981-
tools := setupMCPServerProxiesForTest(t)
991+
// Setup MCP mcpProxiers.
992+
mcpProxiers, acc := setupMCPServerProxiesForTest(t)
982993

983994
// Configure the bridge with injected tools.
984-
mcpMgr := mcp.NewServerProxyManager(tools)
995+
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
985996
require.NoError(t, mcpMgr.Init(ctx))
986997
b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr)
987998
require.NoError(t, err)
@@ -1008,7 +1019,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
10081019
return mockSrv.callCount.Load() == 2
10091020
}, time.Second*10, time.Millisecond*50)
10101021

1011-
return recorderClient, resp
1022+
return recorderClient, acc, resp
10121023
}
10131024

10141025
func TestErrorHandling(t *testing.T) {
@@ -1259,10 +1270,10 @@ func TestStableRequestEncoding(t *testing.T) {
12591270
t.Cleanup(cancel)
12601271

12611272
// Setup MCP tools.
1262-
tools := setupMCPServerProxiesForTest(t)
1273+
mcpProxiers, _ := setupMCPServerProxiesForTest(t)
12631274

12641275
// Configure the bridge with injected tools.
1265-
mcpMgr := mcp.NewServerProxyManager(tools)
1276+
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
12661277
require.NoError(t, mcpMgr.Init(ctx))
12671278

12681279
arc := txtar.Parse(tc.fixture)
@@ -1669,7 +1680,36 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
16691680

16701681
const mockToolName = "coder_list_workspaces"
16711682

1672-
func createMockMCPSrv(t *testing.T) http.Handler {
1683+
// callAccumulator tracks all tool invocations by name and each instance's arguments.
1684+
type callAccumulator struct {
1685+
calls map[string][]any
1686+
callsMu sync.Mutex
1687+
}
1688+
1689+
func newCallAccumulator() *callAccumulator {
1690+
return &callAccumulator{
1691+
calls: make(map[string][]any),
1692+
}
1693+
}
1694+
1695+
func (a *callAccumulator) addCall(tool string, args any) {
1696+
a.callsMu.Lock()
1697+
defer a.callsMu.Unlock()
1698+
1699+
a.calls[tool] = append(a.calls[tool], args)
1700+
}
1701+
1702+
func (a *callAccumulator) getCallsByTool(name string) []any {
1703+
a.callsMu.Lock()
1704+
defer a.callsMu.Unlock()
1705+
1706+
// Protect against concurrent access of the slice.
1707+
result := make([]any, len(a.calls[name]))
1708+
copy(result, a.calls[name])
1709+
return result
1710+
}
1711+
1712+
func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
16731713
t.Helper()
16741714

16751715
s := server.NewMCPServer(
@@ -1678,16 +1718,20 @@ func createMockMCPSrv(t *testing.T) http.Handler {
16781718
server.WithToolCapabilities(true),
16791719
)
16801720

1721+
// Accumulate tool calls & their arguments.
1722+
acc := newCallAccumulator()
1723+
16811724
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
16821725
tool := mcplib.NewTool(name,
16831726
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
16841727
)
16851728
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1729+
acc.addCall(request.Params.Name, request.Params.Arguments)
16861730
return mcplib.NewToolResultText("mock"), nil
16871731
})
16881732
}
16891733

1690-
return server.NewStreamableHTTPServer(s)
1734+
return server.NewStreamableHTTPServer(s), acc
16911735
}
16921736

16931737
func openaiCfg(url, key string) aibridge.OpenAIConfig {

intercept_openai_chat_blocking.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package aibridge
22

33
import (
4-
"bytes"
54
"encoding/json"
65
"fmt"
76
"net/http"
@@ -139,20 +138,15 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
139138
appendedPrevMsg = true
140139
}
141140

142-
var (
143-
args map[string]string
144-
buf bytes.Buffer
145-
)
146-
_ = json.NewEncoder(&buf).Encode(tc.Function.Arguments)
147-
_ = json.NewDecoder(&buf).Decode(&args)
141+
args := i.unmarshalArgs(tc.Function.Arguments)
148142
res, err := tool.Call(ctx, args)
149143

150144
_ = i.recorder.RecordToolUsage(ctx, &ToolUsageRecord{
151145
InterceptionID: i.ID().String(),
152146
MsgID: completion.ID,
153147
ServerURL: &tool.ServerURL,
154148
Tool: tool.Name,
155-
Args: i.unmarshalArgs(tc.Function.Arguments),
149+
Args: args,
156150
Injected: true,
157151
InvocationError: err,
158152
})

metrics_integration_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
236236
provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil)
237237

238238
// Setup mocked MCP server & tools.
239-
tools := setupMCPServerProxiesForTest(t)
240-
mcpMgr := mcp.NewServerProxyManager(tools)
239+
mcpProxiers, _ := setupMCPServerProxiesForTest(t)
240+
mcpMgr := mcp.NewServerProxyManager(mcpProxiers)
241241
require.NoError(t, mcpMgr.Init(ctx))
242242

243243
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, metrics, logger)

0 commit comments

Comments
 (0)