Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion internal/mcp/mcp_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
)

const MCPURLPath = ".api/mcp/v1"
const MCPDeepSearchURLPath = ".api/mcp/deepsearch"

func fetchToolDefinitions(ctx context.Context, client api.Client, endpoint string) (map[string]*ToolDef, error) {
resp, err := doJSONRPC(ctx, client, endpoint, "tools/list", nil)
Expand Down
39 changes: 8 additions & 31 deletions internal/mcp/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,29 @@ import (
"slices"

"github.com/sourcegraph/src-cli/internal/api"

"github.com/sourcegraph/sourcegraph/lib/errors"
)

// ToolRegistry keeps track of tools and the endpoints they originated from
type ToolRegistry struct {
tools map[string]*ToolDef
endpoints map[string]string
tools map[string]*ToolDef
}

func NewToolRegistry() *ToolRegistry {
return &ToolRegistry{
tools: make(map[string]*ToolDef),
endpoints: make(map[string]string),
tools: make(map[string]*ToolDef),
}
}

// LoadTools loads the tool definitions from the Mcp tool endpoints constants McpURLPath and McpDeepSearchURLPath
// LoadTools loads the tool definitions from the Mcp tool endpoints constants McpURLPath
func (r *ToolRegistry) LoadTools(ctx context.Context, client api.Client) error {
endpoints := []string{MCPURLPath, MCPDeepSearchURLPath}

var errs []error
for _, endpoint := range endpoints {
tools, err := fetchToolDefinitions(ctx, client, endpoint)
if err != nil {
errs = append(errs, errors.Wrapf(err, "failed to load tools from %s", endpoint))
continue
}
r.register(endpoint, tools)
}

if len(errs) > 0 {
return errors.Append(nil, errs...)
tools, err := fetchToolDefinitions(ctx, client, MCPURLPath)
if err != nil {
return err
}
r.tools = tools
return nil
}

// register associates a collection of tools with the given endpoint
func (r *ToolRegistry) register(endpoint string, tools map[string]*ToolDef) {
for name, def := range tools {
r.tools[name] = def
r.endpoints[name] = endpoint
}
}

// Get returns the tool definition for the given name
func (r *ToolRegistry) Get(name string) (*ToolDef, bool) {
tool, ok := r.tools[name]
Expand All @@ -62,8 +40,7 @@ func (r *ToolRegistry) Get(name string) (*ToolDef, bool) {
// CallTool calls the given tool with the given arguments. It constructs the Tool request and decodes the Tool response
func (r *ToolRegistry) CallTool(ctx context.Context, client api.Client, name string, args map[string]any) (map[string]json.RawMessage, error) {
tool := r.tools[name]
endpoint := r.endpoints[name]
resp, err := doToolCall(ctx, client, endpoint, tool.RawName, args)
resp, err := doToolCall(ctx, client, MCPURLPath, tool.RawName, args)
if err != nil {
return nil, err
}
Expand Down
Loading