diff --git a/go/cli/internal/cli/mcp/root.go b/go/cli/internal/cli/mcp/root.go index 43d8932a8..93639baf3 100644 --- a/go/cli/internal/cli/mcp/root.go +++ b/go/cli/internal/cli/mcp/root.go @@ -18,6 +18,7 @@ Model Context Protocol servers with dynamic tool loading.`, mcpCmd.AddCommand(DeployCmd) mcpCmd.AddCommand(AddToolCmd) mcpCmd.AddCommand(RunCmd) + mcpCmd.AddCommand(ServeAgentsCmd) mcpCmd.AddCommand(SecretsCmd) return mcpCmd diff --git a/go/cli/internal/cli/mcp/serve_mcp.go b/go/cli/internal/cli/mcp/serve_mcp.go new file mode 100644 index 000000000..24ed9a326 --- /dev/null +++ b/go/cli/internal/cli/mcp/serve_mcp.go @@ -0,0 +1,209 @@ +package mcp + +import ( + "context" + "fmt" + "net/http" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/kagent-dev/kagent/go/cli/internal/config" + "github.com/kagent-dev/kagent/go/internal/a2a" + "github.com/kagent-dev/kagent/go/internal/version" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/spf13/cobra" + a2aclient "trpc.group/trpc-go/trpc-a2a-go/client" + "trpc.group/trpc-go/trpc-a2a-go/protocol" +) + +var ( + serveAgentsTransport string + serveAgentsHost string + serveAgentsPort int +) + +var a2aContextBySessionAndAgent sync.Map + +var fallbackInvocationCounter uint64 + +var ServeAgentsCmd = &cobra.Command{ + Use: "serve-mcp", + Short: "Serve kagent agents via MCP", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Get() + if err != nil { + return fmt.Errorf("config: %w", err) + } + hooks := &mcpserver.Hooks{} + hooks.AddOnUnregisterSession(func(ctx context.Context, session mcpserver.ClientSession) { + sessionID := session.SessionID() + a2aContextBySessionAndAgent.Range(func(key, _ any) bool { + keyStr, ok := key.(string) + if !ok { + return true + } + if strings.HasPrefix(keyStr, sessionID+"|") { + a2aContextBySessionAndAgent.Delete(key) + } + return true + }) + }) + s := mcpserver.NewMCPServer( + "kagent-agents", + version.Version, + mcpserver.WithToolCapabilities(false), + mcpserver.WithHooks(hooks), + ) + + s.AddTool(mcp.NewTool("list_agents", + mcp.WithDescription("List invokable kagent agents (accepted + deploymentReady)"), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resp, err := cfg.Client().Agent.ListAgents(ctx) + if err != nil { + return mcp.NewToolResultErrorFromErr("list agents", err), nil + } + type agentSummary struct { + Ref string `json:"ref"` + Description string `json:"description,omitempty"` + } + agents := make([]agentSummary, 0) + for _, agent := range resp.Data { + if !agent.Accepted || !agent.DeploymentReady || agent.Agent == nil { + continue + } + ref := agent.Agent.Namespace + "/" + agent.Agent.Name + agents = append(agents, agentSummary{Ref: ref, Description: agent.Agent.Spec.Description}) + } + if len(agents) == 0 { + return mcp.NewToolResultStructured(agents, "No invokable agents found."), nil + } + + var fallbackText strings.Builder + for i, agent := range agents { + if i > 0 { + fallbackText.WriteByte('\n') + } + fallbackText.WriteString(agent.Ref) + if agent.Description != "" { + fallbackText.WriteString(" - ") + fallbackText.WriteString(agent.Description) + } + } + + return mcp.NewToolResultStructured(agents, fallbackText.String()), nil + }) + + s.AddTool(mcp.NewTool("invoke_agent", + mcp.WithDescription("Invoke a kagent agent via A2A"), + mcp.WithString("agent", mcp.Description("Agent name (or namespace/name)"), mcp.Required()), + mcp.WithString("task", mcp.Description("Task to run"), mcp.Required()), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + agentRef, err := request.RequireString("agent") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + task, err := request.RequireString("task") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + agentNS, agentName, ok := strings.Cut(agentRef, "/") + if !ok { + agentNS, agentName = cfg.Namespace, agentRef + } + agentRef = agentNS + "/" + agentName + + sessionID := "" + if session := mcpserver.ClientSessionFromContext(ctx); session != nil { + sessionID = session.SessionID() + } else if headerSessionID := request.Header.Get(mcpserver.HeaderKeySessionID); headerSessionID != "" { + sessionID = headerSessionID + } + if sessionID == "" { + sessionID = fmt.Sprintf("invocation-%d", atomic.AddUint64(&fallbackInvocationCounter, 1)) + } + contextKey := sessionID + "|" + agentRef + var contextIDPtr *string + if prior, ok := a2aContextBySessionAndAgent.Load(contextKey); ok { + if priorStr, ok := prior.(string); ok && priorStr != "" { + contextIDPtr = &priorStr + } + } + + a2aURL := fmt.Sprintf("%s/api/a2a/%s/%s", cfg.KAgentURL, agentNS, agentName) + client, err := a2aclient.NewA2AClient(a2aURL, a2aclient.WithTimeout(cfg.Timeout)) + if err != nil { + return mcp.NewToolResultErrorFromErr("a2a client", err), nil + } + result, err := client.SendMessage(ctx, protocol.SendMessageParams{Message: protocol.Message{ + Kind: protocol.KindMessage, Role: protocol.MessageRoleUser, ContextID: contextIDPtr, Parts: []protocol.Part{protocol.NewTextPart(task)}, + }}) + if err != nil { + return mcp.NewToolResultErrorFromErr("a2a send", err), nil + } + + var responseText, newContextID string + switch a2aResult := result.Result.(type) { + case *protocol.Message: + responseText = a2a.ExtractText(*a2aResult) + if a2aResult.ContextID != nil { + newContextID = *a2aResult.ContextID + } + case *protocol.Task: + newContextID = a2aResult.ContextID + if a2aResult.Status.Message != nil { + responseText = a2a.ExtractText(*a2aResult.Status.Message) + } + for _, artifact := range a2aResult.Artifacts { + responseText += a2a.ExtractText(protocol.Message{Parts: artifact.Parts}) + } + } + if responseText == "" { + raw, err := result.MarshalJSON() + if err != nil { + return mcp.NewToolResultErrorFromErr("marshal result", err), nil + } + responseText = string(raw) + } + if newContextID != "" { + a2aContextBySessionAndAgent.Store(contextKey, newContextID) + } + return mcp.NewToolResultStructured(map[string]any{ + "agent": agentRef, + "context_id": newContextID, + "text": responseText, + }, responseText), nil + }) + + switch strings.ToLower(serveAgentsTransport) { + case "stdio": + stdioServer := mcpserver.NewStdioServer(s) + return stdioServer.Listen(cmd.Context(), os.Stdin, os.Stdout) + case "http": + addr := fmt.Sprintf("%s:%d", serveAgentsHost, serveAgentsPort) + cmd.PrintErrf("MCP server listening on http://%s/mcp\n", addr) + httpServer := mcpserver.NewStreamableHTTPServer(s) + go func() { + <-cmd.Context().Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = httpServer.Shutdown(shutdownCtx) + }() + if err := httpServer.Start(addr); err != nil && err != http.ErrServerClosed { + return err + } + return nil + default: + return fmt.Errorf("invalid transport %q (expected stdio or http)", serveAgentsTransport) + } + }, +} + +func init() { + ServeAgentsCmd.Flags().StringVar(&serveAgentsTransport, "transport", "stdio", "Transport mode (stdio or http)") + ServeAgentsCmd.Flags().StringVar(&serveAgentsHost, "host", "127.0.0.1", "HTTP host to bind (when --transport http)") + ServeAgentsCmd.Flags().IntVar(&serveAgentsPort, "port", 3000, "HTTP port to bind (when --transport http)") +} diff --git a/go/test/e2e/invoke_api_test.go b/go/test/e2e/invoke_api_test.go index ddeda2156..7b37bd044 100644 --- a/go/test/e2e/invoke_api_test.go +++ b/go/test/e2e/invoke_api_test.go @@ -508,7 +508,7 @@ func TestE2EInvokeDeclarativeAgentWithMcpServerTool(t *testing.T) { // Setup Kubernetes client (include v1alpha1 for MCPServer) cli := setupK8sClient(t, true) mcpServer := setupMCPServer(t, cli) - // Define tools + // Define tools tools := []*v1alpha2.Tool{ { Type: v1alpha2.ToolProviderType_McpServer, @@ -533,7 +533,7 @@ func TestE2EInvokeDeclarativeAgentWithMcpServerTool(t *testing.T) { // Run tests t.Run("sync_invocation", func(t *testing.T) { - runSyncTest(t, a2aClient, "add 3 and 5", "8", nil) + runSyncTest(t, a2aClient, "add 3 and 5. To add two numbers, call the tool named get-sum", "8", nil) }) t.Run("streaming_invocation", func(t *testing.T) { diff --git a/go/test/e2e/mcp_serve_agents_test.go b/go/test/e2e/mcp_serve_agents_test.go new file mode 100644 index 000000000..235f21be8 --- /dev/null +++ b/go/test/e2e/mcp_serve_agents_test.go @@ -0,0 +1,144 @@ +package e2e_test + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestE2EInvokeAgentThroughMCPServeAgents(t *testing.T) { + // Setup mock server (so agent responses are deterministic and don't hit real LLMs) + baseURL, stopServer := setupMockServer(t, "mocks/invoke_mcp_serve_agents.json") + defer stopServer() + + // Setup Kubernetes resources for a known-good agent + cli := setupK8sClient(t, false) + modelCfg := setupModelConfig(t, cli, baseURL) + agent := setupAgentWithOptions(t, cli, modelCfg.Name, nil, AgentOptions{ + Name: "kebab-agent", + }) + + kagentURL := os.Getenv("KAGENT_URL") + if kagentURL == "" { + kagentURL = "http://localhost:8083" + } + + _, testFile, _, ok := runtime.Caller(0) + require.True(t, ok) + goModuleRoot := filepath.Clean(filepath.Join(filepath.Dir(testFile), "../..")) + + kagentBin := filepath.Join(t.TempDir(), "kagent") + build := exec.Command("go", "build", "-o", kagentBin, "./cli/cmd/kagent") + build.Dir = goModuleRoot + buildOutput, err := build.CombinedOutput() + require.NoError(t, err, string(buildOutput)) + + homeDir := t.TempDir() + cfgDir := filepath.Join(homeDir, ".kagent") + require.NoError(t, os.MkdirAll(cfgDir, 0755)) + cfgPath := filepath.Join(cfgDir, "config.yaml") + require.NoError(t, os.WriteFile(cfgPath, []byte(fmt.Sprintf("kagent_url: %s\nnamespace: kagent\ntimeout: 300s\n", kagentURL)), 0644)) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, kagentBin, "mcp", "serve-agents") + cmd.Env = append(os.Environ(), "HOME="+homeDir) + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + stdin, err := cmd.StdinPipe() + require.NoError(t, err) + var stderr bytes.Buffer + cmd.Stderr = &stderr + require.NoError(t, cmd.Start()) + t.Cleanup(func() { + _ = stdin.Close() + _ = cmd.Process.Kill() + _ = cmd.Wait() + }) + + lines := make(chan string, 32) + go func() { + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + lines <- scanner.Text() + } + close(lines) + }() + + writeLine := func(line string) { + _, _ = fmt.Fprintln(stdin, line) + } + + readResponse := func(wantID int) json.RawMessage { + deadline := time.NewTimer(15 * time.Second) + defer deadline.Stop() + for { + select { + case line, ok := <-lines: + require.True(t, ok, stderr.String()) + var msg struct { + ID int `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + } + require.NoError(t, json.Unmarshal([]byte(line), &msg), line) + if msg.ID != wantID { + continue + } + require.Nil(t, msg.Error, line) + return msg.Result + case <-deadline.C: + t.Fatalf("timed out waiting for id=%d; stderr=%s", wantID, stderr.String()) + } + } + } + + writeLine(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"e2e","version":"0.0.0"}}}`) + _ = readResponse(1) + writeLine(`{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}`) + + writeLine(`{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`) + toolsList := readResponse(2) + var listResult struct { + Tools []struct { + Name string `json:"name"` + } `json:"tools"` + } + require.NoError(t, json.Unmarshal(toolsList, &listResult), string(toolsList)) + require.GreaterOrEqual(t, len(listResult.Tools), 2) + toolNames := make([]string, 0, len(listResult.Tools)) + for _, tool := range listResult.Tools { + toolNames = append(toolNames, tool.Name) + } + require.Contains(t, toolNames, "list_agents") + require.Contains(t, toolNames, "invoke_agent") + + writeLine(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"list_agents"}}`) + agentsResult := readResponse(3) + var callResult struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } + require.NoError(t, json.Unmarshal(agentsResult, &callResult), string(agentsResult)) + require.NotEmpty(t, callResult.Content) + require.Contains(t, callResult.Content[0].Text, agent.Namespace+"/"+agent.Name) + + writeLine(fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"invoke_agent","arguments":{"agent":%q,"task":"What can you do?"}}}`, agent.Name)) + invokeResult := readResponse(4) + require.NoError(t, json.Unmarshal(invokeResult, &callResult), string(invokeResult)) + require.NotEmpty(t, callResult.Content) + require.Contains(t, callResult.Content[0].Text, "kebab") +} diff --git a/go/test/e2e/mocks/invoke_mcp_serve_agents.json b/go/test/e2e/mocks/invoke_mcp_serve_agents.json new file mode 100644 index 000000000..efd569d28 --- /dev/null +++ b/go/test/e2e/mocks/invoke_mcp_serve_agents.json @@ -0,0 +1,31 @@ +{ + "openai": [ + { + "name": "serve_agents_request", + "match": { + "match_type": "contains", + "message": { + "content": "What can you do?", + "role": "user" + } + }, + "response": { + "id": "chatcmpl-1", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4.1-mini", + "choices": [ + { + "index": 0, + "role": "assistant", + "message": { + "content": "I can answer questions and help you with tasks. Also: kebab.", + "role": "assistant" + }, + "finish_reason": "stop" + } + ] + } + } + ] +}