Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 15 additions & 26 deletions cmd/src/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"strings"

"github.com/sourcegraph/src-cli/internal/api"
"github.com/sourcegraph/src-cli/internal/mcp"

"github.com/sourcegraph/sourcegraph/lib/errors"
Expand Down Expand Up @@ -36,8 +35,8 @@ func mcpMain(args []string) error {
apiClient := cfg.apiClient(nil, mcpFlagSet.Output())

ctx := context.Background()
tools, err := mcp.FetchToolDefinitions(ctx, apiClient)
if err != nil {
registry := mcp.NewToolRegistry()
if err := registry.LoadTools(ctx, apiClient); err != nil {
return err
}

Expand All @@ -49,7 +48,7 @@ func mcpMain(args []string) error {
subcmd := args[0]
if subcmd == "list-tools" {
fmt.Println("The following tools are available:")
for name := range tools {
for name := range registry.All() {
fmt.Printf(" %s\n", name)
}
fmt.Println("\nUSAGE:")
Expand All @@ -58,7 +57,7 @@ func mcpMain(args []string) error {
fmt.Println(" src mcp <tool-name> -h List the available flags of a tool")
return nil
}
tool, ok := tools[subcmd]
tool, ok := registry.Get(subcmd)
if !ok {
return errors.Newf("tool definition for %q not found - run src mcp list-tools to see a list of available tools", subcmd)
}
Expand All @@ -81,7 +80,17 @@ func mcpMain(args []string) error {
return err
}

return handleMcpTool(context.Background(), apiClient, tool, vars)
result, err := registry.CallTool(ctx, apiClient, tool.Name, vars)
if err != nil {
return err
}

output, err := json.MarshalIndent(result, "", " ")
if err != nil {
return err
}
fmt.Println(string(output))
return nil
}

func printSchemas(tool *mcp.ToolDef) error {
Expand Down Expand Up @@ -111,23 +120,3 @@ func validateToolArgs(inputSchema mcp.SchemaObject, args []string, vars map[stri

return nil
}

func handleMcpTool(ctx context.Context, client api.Client, tool *mcp.ToolDef, vars map[string]any) error {
resp, err := mcp.DoToolCall(ctx, client, tool.RawName, vars)
if err != nil {
return err
}

result, err := mcp.DecodeToolResponse(resp)
if err != nil {
return err
}
defer resp.Body.Close()

output, err := json.MarshalIndent(result, "", " ")
if err != nil {
return err
}
fmt.Println(string(output))
return nil
}
19 changes: 10 additions & 9 deletions internal/mcp/mcp_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
"github.com/sourcegraph/sourcegraph/lib/errors"
)

const McpURLPath = ".api/mcp/v1"
const MCPURLPath = ".api/mcp/v1"
const MCPDeepSearchURLPath = ".api/mcp/deepsearch"

func FetchToolDefinitions(ctx context.Context, client api.Client) (map[string]*ToolDef, error) {
resp, err := doJSONRPC(ctx, client, "tools/list", nil)
func fetchToolDefinitions(ctx context.Context, client api.Client, endpoint string) (map[string]*ToolDef, error) {
resp, err := doJSONRPC(ctx, client, endpoint, "tools/list", nil)
if err != nil {
return nil, errors.Wrap(err, "failed to list tools from mcp endpoint")
}
Expand Down Expand Up @@ -44,7 +45,7 @@ func FetchToolDefinitions(ctx context.Context, client api.Client) (map[string]*T
return loadToolDefinitions(rpcResp.Result)
}

func DoToolCall(ctx context.Context, client api.Client, tool string, vars map[string]any) (*http.Response, error) {
func doToolCall(ctx context.Context, client api.Client, endpoint string, tool string, vars map[string]any) (*http.Response, error) {
params := struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
Expand All @@ -53,10 +54,10 @@ func DoToolCall(ctx context.Context, client api.Client, tool string, vars map[st
Arguments: vars,
}

return doJSONRPC(ctx, client, "tools/call", params)
return doJSONRPC(ctx, client, endpoint, "tools/call", params)
}

func doJSONRPC(ctx context.Context, client api.Client, method string, params any) (*http.Response, error) {
func doJSONRPC(ctx context.Context, client api.Client, endpoint string, method string, params any) (*http.Response, error) {
jsonRPC := struct {
Version string `json:"jsonrpc"`
ID int `json:"id"`
Expand All @@ -75,7 +76,7 @@ func doJSONRPC(ctx context.Context, client api.Client, method string, params any
}
buf.Write(data)

req, err := client.NewHTTPRequest(ctx, http.MethodPost, McpURLPath, buf)
req, err := client.NewHTTPRequest(ctx, http.MethodPost, endpoint, buf)
if err != nil {
return nil, err
}
Expand All @@ -91,13 +92,13 @@ func doJSONRPC(ctx context.Context, client api.Client, method string, params any
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
resp.Body.Close()
return nil, errors.Newf("MCP endpoint %s returned %d: %s",
McpURLPath, resp.StatusCode, strings.TrimSpace(string(body)))
endpoint, resp.StatusCode, strings.TrimSpace(string(body)))
}

return resp, nil
}

func DecodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
func decodeToolResponse(resp *http.Response) (map[string]json.RawMessage, error) {
data, err := readSSEResponseData(resp)
if err != nil {
return nil, err
Expand Down
81 changes: 81 additions & 0 deletions internal/mcp/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package mcp

import (
"context"
"encoding/json"
"iter"

"github.com/sourcegraph/src-cli/internal/api"

"github.com/sourcegraph/sourcegraph/lib/errors"
)

// ToolRegistry keeps track of tools and the endpoints they originated from
type ToolRegistry struct {
tools map[string]*ToolDef
endpoints map[string]string
}

func NewToolRegistry() *ToolRegistry {
return &ToolRegistry{
tools: make(map[string]*ToolDef),
endpoints: make(map[string]string),
}
}

// LoadTools loads the tool definitions from the Mcp tool endpoints constants McpURLPath and McpDeepSearchURLPath
func (r *ToolRegistry) LoadTools(ctx context.Context, client api.Client) error {
endpoints := []string{MCPURLPath, MCPDeepSearchURLPath}

var errs []error
for _, endpoint := range endpoints {
tools, err := fetchToolDefinitions(ctx, client, endpoint)
if err != nil {
errs = append(errs, errors.Wrapf(err, "failed to load tools from %s", endpoint))
continue
}
r.register(endpoint, tools)
}

if len(errs) > 0 {
return errors.Append(nil, errs...)
}
return nil
}

// register associates a collection of tools with the given endpoint
func (r *ToolRegistry) register(endpoint string, tools map[string]*ToolDef) {
for name, def := range tools {
r.tools[name] = def
r.endpoints[name] = endpoint
}
}

// Get returns the tool definition for the given name
func (r *ToolRegistry) Get(name string) (*ToolDef, bool) {
tool, ok := r.tools[name]
return tool, ok
}

// CallTool calls the given tool with the given arguments. It constructs the Tool request and decodes the Tool response
func (r *ToolRegistry) CallTool(ctx context.Context, client api.Client, name string, args map[string]any) (map[string]json.RawMessage, error) {
tool := r.tools[name]
endpoint := r.endpoints[name]
resp, err := doToolCall(ctx, client, endpoint, tool.RawName, args)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return decodeToolResponse(resp)
}

// All returns an iterator that yields the name and Tool definition of all registered tools
func (r *ToolRegistry) All() iter.Seq2[string, *ToolDef] {
return func(yield func(string, *ToolDef) bool) {
for name, def := range r.tools {
if !yield(name, def) {
return
}
}
}
}
Loading