diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index ef1440e..3a7b758 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -29,14 +29,10 @@ func (d *dispatcher) handle(ctx context.Context, method string, req mcp.Request, switch method { case "tools/list", "resources/list", "prompts/list", "resources/templates/list": return d.handleList(ctx, method, req) - case "tools/call", "resources/read", "prompts/get": - return d.handleCall(ctx, method, req) - case "resources/subscribe": - return d.handleSubscribe(ctx, req) - case "resources/unsubscribe": - return d.handleUnsubscribe(ctx, req) - case "completion/complete": - return d.handleCompletion(ctx, req) + case "tools/call", "resources/read", "prompts/get", + "resources/subscribe", "resources/unsubscribe", + "completion/complete": + return d.handleDirect(ctx, method, req) default: return next(ctx, method, req) } @@ -66,16 +62,48 @@ func (d *dispatcher) createInvalidVariantError(ctx context.Context, requestedVar } } +// isNilInterface checks if v is nil or a typed-nil (a nil pointer wrapped in an interface). +func isNilInterface(v any) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + return rv.Kind() == reflect.Ptr && rv.IsNil() +} + +// The reflect-based field access (cursor unwrap/wrap, metadata injection) calls +// Elem() to dereference pointers, which panics on non-pointer types. Even without +// the panic, mutations on a value type would not propagate back to the original +// request/result. The current SDK satisfies this: all Params and Result types use +// pointer receivers for their interface marker methods (isParams/isResult), so only +// pointer types can implement the interfaces. These checks guard against future +// SDK changes. +var ( + errParamsNotPointer = errors.New("variants: expected pointer type for Params, got value type") + errResultNotPointer = errors.New("variants: expected pointer type for Result, got value type") +) + +// Compile-time assertions: ensure list param/result types have the Cursor and +// NextCursor fields that handleList accesses via reflection. +var ( + _ = mcp.ListToolsParams{}.Cursor + _ = mcp.ListResourcesParams{}.Cursor + _ = mcp.ListPromptsParams{}.Cursor + _ = mcp.ListResourceTemplatesParams{}.Cursor + + _ = mcp.ListToolsResult{}.NextCursor + _ = mcp.ListResourcesResult{}.NextCursor + _ = mcp.ListPromptsResult{}.NextCursor + _ = mcp.ListResourceTemplatesResult{}.NextCursor +) + // variantIDFromMeta extracts the variant ID from the request's _meta field. // Returns empty string if no variant is specified. Guards against typed-nil // params (e.g. (*ListToolsParams)(nil) wrapped in the mcp.Params interface) // which the SDK can produce for requests with no parameters. func variantIDFromMeta(req mcp.Request) string { params := req.GetParams() - if params == nil { - return "" - } - if v := reflect.ValueOf(params); v.Kind() == reflect.Ptr && v.IsNil() { + if isNilInterface(params) { return "" } meta := params.GetMeta() @@ -160,7 +188,7 @@ func enrichError(err error, variantID string) error { // List methods // --------------------------------------------------------------------------- -// handleList handles list methods by forwarding to the appropriate variant. +// handleList handles list methods using the generic backend session call method. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) @@ -171,230 +199,70 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ backendSession := conn.backendSession variantID := backendSession.variantID params := req.GetParams() - extra := req.GetExtra() - switch method { - case "tools/list": - p, _ := params.(*mcp.ListToolsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListTools(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/list": - p, _ := params.(*mcp.ListResourcesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListResources(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "prompts/list": - p, _ := params.(*mcp.ListPromptsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListPrompts(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) + // Inject variant metadata and handle cursor unwrapping (guard against typed-nil params) + if !isNilInterface(params) { + if reflect.ValueOf(params).Kind() != reflect.Ptr { + return nil, errParamsNotPointer } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/templates/list": - p, _ := params.(*mcp.ListResourceTemplatesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListResourceTemplates(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - default: - return nil, errors.New("unsupported list method: " + method) - } -} - -// --------------------------------------------------------------------------- -// Call methods -// --------------------------------------------------------------------------- - -// handleCall handles call methods (tools/call, resources/read, prompts/get). -func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) - if err != nil { - return nil, err - } - - backendSession := conn.backendSession - variantID := backendSession.variantID - params := req.GetParams() - extra := req.GetExtra() - var result mcp.Result + injectVariantMeta(params, variantID) - switch method { - case "tools/call": - raw, _ := params.(*mcp.CallToolParamsRaw) - if raw == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid tools/call params", - } - } - injectVariantMeta(raw, variantID) - result, err = backendSession.CallTool(ctx, raw, extra) - case "resources/read": - p, _ := params.(*mcp.ReadResourceParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/read params", + if f := reflect.ValueOf(params).Elem().FieldByName("Cursor"); f.IsValid() && f.String() != "" { + innerCursor, err := unwrapCursor(f.String(), variantID) + if err != nil { + return nil, err } + f.SetString(innerCursor) } - injectVariantMeta(p, variantID) - result, err = backendSession.ReadResource(ctx, p, extra) - case "prompts/get": - p, _ := params.(*mcp.GetPromptParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid prompts/get params", - } - } - injectVariantMeta(p, variantID) - result, err = backendSession.GetPrompt(ctx, p, extra) - default: - return nil, errors.New("unsupported call method: " + method) } + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { return nil, enrichError(err, variantID) } - return result, nil -} - -// --------------------------------------------------------------------------- -// Subscription methods -// --------------------------------------------------------------------------- - -// handleSubscribe handles resources/subscribe. -func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) - if err != nil { - return nil, err + if isNilInterface(result) { + return nil, nil } - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.SubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/subscribe params", - } + if reflect.ValueOf(result).Kind() != reflect.Ptr { + return nil, errResultNotPointer } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Subscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) - } - return nil, nil -} - -// handleUnsubscribe handles resources/unsubscribe. -// Per SEP-2053: "Servers MUST continue to accept resources/unsubscribe for -// existing subscription ids even if the underlying resource is no longer available." -func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) - if err != nil { - return nil, err + if f := reflect.ValueOf(result).Elem().FieldByName("NextCursor"); f.IsValid() && f.String() != "" { + f.SetString(wrapCursor(f.String(), variantID)) } - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.UnsubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/unsubscribe params", - } - } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Unsubscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) - } - return nil, nil + return result, nil } // --------------------------------------------------------------------------- -// Completion +// Simple methods (no pagination) // --------------------------------------------------------------------------- -// handleCompletion handles completion/complete. -func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp.Result, error) { +// handleDirect handles all simple methods (call, subscribe, unsubscribe, completion) +// that don't require special cursor handling. +func (d *dispatcher) handleDirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.CompleteParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid completion/complete params", + variantID := backendSession.variantID + params := req.GetParams() + + // Inject variant metadata (guard against typed-nil params) + if !isNilInterface(params) { + if reflect.ValueOf(params).Kind() != reflect.Ptr { + return nil, errParamsNotPointer } + injectVariantMeta(params, variantID) } - injectVariantMeta(params, backendSession.variantID) - result, err := backendSession.Complete(ctx, params, req.GetExtra()) + + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { - return nil, enrichError(err, backendSession.variantID) + return nil, enrichError(err, variantID) } + return result, nil } diff --git a/go/sdk/variants/dispatch_test.go b/go/sdk/variants/dispatch_test.go new file mode 100644 index 0000000..b8e0a57 --- /dev/null +++ b/go/sdk/variants/dispatch_test.go @@ -0,0 +1,684 @@ +// Copyright 2025 The MCP Variants Authors. All rights reserved. +// Use of this source code is governed by a Apache-2.0 +// license that can be found in the LICENSE file. + +package variants + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// isNilInterface +// --------------------------------------------------------------------------- + +func TestIsNilInterface(t *testing.T) { + tests := []struct { + name string + val any + want bool + }{ + {"untyped nil", nil, true}, + {"typed nil pointer", (*mcp.ListToolsParams)(nil), true}, + {"non-nil pointer", &mcp.ListToolsParams{}, false}, + {"non-pointer value", 42, false}, + {"string value", "hello", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isNilInterface(tt.val)) + }) + } +} + +// --------------------------------------------------------------------------- +// handleList +// --------------------------------------------------------------------------- + +// newTestDispatcher builds a dispatcher with a single variant backed by a +// custom mcpMethodHandler. The handler receives the request exactly as the +// dispatcher sends it, so tests can inspect params/session mutations. +func newTestDispatcher(variantID string, handler mcp.MethodHandler) *dispatcher { + return &dispatcher{ + server: NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}). + WithVariant(ServerVariant{ID: variantID, Status: Stable}, mcp.NewServer(&mcp.Implementation{Name: "inner", Version: "v0.0.1"}, nil), 0), + connections: map[string]*innerConnection{ + variantID: { + backendSession: &backendSession{ + variantID: variantID, + mcpMethodHandler: handler, + }, + }, + }, + } +} + +func TestHandleList_NilResult(t *testing.T) { + const variantID = "test-variant" + + tests := []struct { + name string + handler func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) + }{ + { + name: "untyped nil", + handler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, nil + }, + }, + { + name: "typed nil", + handler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return (*mcp.ListToolsResult)(nil), nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := newTestDispatcher(variantID, tt.handler) + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + result, err := d.handleList(context.Background(), "tools/list", req) + require.NoError(t, err) + assert.Nil(t, result) + }) + } +} + +func TestHandleList_CursorRoundTrip(t *testing.T) { + const variantID = "v1" + const innerCursor = "page-2-token" + + tests := []struct { + name string + method string + req func(wrappedCursor string) mcp.Request + result func() mcp.Result // result returned by backend with NextCursor set + }{ + { + name: "tools/list", + method: "tools/list", + req: func(c string) mcp.Request { + return &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListToolsResult{NextCursor: "next-inner"} + }, + }, + { + name: "resources/list", + method: "resources/list", + req: func(c string) mcp.Request { + return &mcp.ListResourcesRequest{ + Params: &mcp.ListResourcesParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListResourcesResult{NextCursor: "next-inner"} + }, + }, + { + name: "prompts/list", + method: "prompts/list", + req: func(c string) mcp.Request { + return &mcp.ListPromptsRequest{ + Params: &mcp.ListPromptsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListPromptsResult{NextCursor: "next-inner"} + }, + }, + { + name: "resources/templates/list", + method: "resources/templates/list", + req: func(c string) mcp.Request { + return &mcp.ListResourceTemplatesRequest{ + Params: &mcp.ListResourceTemplatesParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListResourceTemplatesResult{NextCursor: "next-inner"} + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wrappedCursor := wrapCursor(innerCursor, variantID) + var receivedCursor string + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // Capture what cursor the backend actually receives. + params := req.GetParams() + f := reflect.ValueOf(params).Elem().FieldByName("Cursor") + receivedCursor = f.String() + return tt.result(), nil + }) + + result, err := d.handleList(context.Background(), tt.method, tt.req(wrappedCursor)) + require.NoError(t, err) + + // Backend should have received the inner (unwrapped) cursor. + assert.Equal(t, innerCursor, receivedCursor, "cursor should be unwrapped before reaching backend") + + // Result's NextCursor should now be wrapped. + nextCursor := reflect.ValueOf(result).Elem().FieldByName("NextCursor").String() + unwrapped, err := unwrapCursor(nextCursor, variantID) + require.NoError(t, err) + assert.Equal(t, "next-inner", unwrapped, "NextCursor should be wrapped with variant ID") + }) + } +} + +func TestHandleList_NoCursor(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return &mcp.ListToolsResult{}, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + result, err := d.handleList(context.Background(), "tools/list", req) + require.NoError(t, err) + + // No NextCursor should be set when backend returns empty cursor. + r := result.(*mcp.ListToolsResult) + assert.Empty(t, r.NextCursor) +} + +func TestHandleList_InvalidCursor(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("backend should not be called with invalid cursor") + return nil, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: "not-valid-base64!@#$", + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) +} + +func TestHandleList_CrossVariantCursor(t *testing.T) { + const variantID = "v1" + otherVariantCursor := wrapCursor("page2", "other-variant") + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("backend should not be called with cross-variant cursor") + return nil, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: otherVariantCursor, + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) + assert.Contains(t, jErr.Message, "Cursor invalid for requested variant") +} + +func TestHandleList_ErrorEnrichment(t *testing.T) { + const variantID = "v1" + + backendErr := &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "unknown tool", + } + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, backendErr + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), variantID) +} + +func TestHandleList_NilParams(t *testing.T) { + const variantID = "v1" + called := false + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + called = true + return &mcp.ListToolsResult{}, nil + }) + + // Typed-nil params — the SDK produces these for parameterless requests. + req := &mcp.ListToolsRequest{ + Params: (*mcp.ListToolsParams)(nil), + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.NoError(t, err) + assert.True(t, called, "backend should still be called with nil params") +} + +// --------------------------------------------------------------------------- +// handleDirect +// --------------------------------------------------------------------------- + +func TestHandleDirect_MethodRouting(t *testing.T) { + const variantID = "v1" + + tests := []struct { + name string + method string + req mcp.Request + }{ + { + name: "tools/call", + method: "tools/call", + req: &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-tool", + }, + }, + }, + { + name: "resources/read", + method: "resources/read", + req: &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///test", + }, + }, + }, + { + name: "prompts/get", + method: "prompts/get", + req: &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-prompt", + }, + }, + }, + { + name: "resources/subscribe", + method: "resources/subscribe", + req: &mcp.SubscribeRequest{ + Params: &mcp.SubscribeParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///watch", + }, + }, + }, + { + name: "resources/unsubscribe", + method: "resources/unsubscribe", + req: &mcp.UnsubscribeRequest{ + Params: &mcp.UnsubscribeParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///watch", + }, + }, + }, + { + name: "completion/complete", + method: "completion/complete", + req: &mcp.CompleteRequest{ + Params: &mcp.CompleteParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedMethod string + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedMethod = method + return nil, nil + }) + + _, err := d.handleDirect(context.Background(), tt.method, tt.req) + require.NoError(t, err) + assert.Equal(t, tt.method, receivedMethod, "backend should receive the correct method") + }) + } +} + +func TestHandleDirect_MetaInjection(t *testing.T) { + const variantID = "v1" + + var receivedMeta mcp.Meta + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedMeta = req.GetParams().GetMeta() + return &mcp.CallToolResult{}, nil + }) + + // Start without variant meta — handleDirect should inject it. + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "my-tool", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + require.NoError(t, err) + assert.Equal(t, variantID, receivedMeta[metaKeyVariant], "variant meta should be injected into params") +} + +func TestHandleDirect_ErrorEnrichment(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "tool not found", + } + }) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "nonexistent", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), variantID) +} + +func TestHandleDirect_NonEnrichableError(t *testing.T) { + const variantID = "v1" + plainErr := errors.New("internal error") + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, plainErr + }) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-tool", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + assert.Equal(t, plainErr, err, "non-jsonrpc errors should pass through unmodified") +} + +func TestHandleDirect_NilParams(t *testing.T) { + const variantID = "v1" + called := false + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + called = true + return nil, nil + }) + + req := &mcp.SubscribeRequest{ + Params: (*mcp.SubscribeParams)(nil), + } + + _, err := d.handleDirect(context.Background(), "resources/subscribe", req) + require.NoError(t, err) + assert.True(t, called, "backend should still be called when params are typed-nil") +} + +// --------------------------------------------------------------------------- +// handle (top-level router) +// --------------------------------------------------------------------------- + +func TestHandle_UnknownMethodPassthrough(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("dispatcher handler should not be called for unknown methods") + return nil, nil + }) + + nextCalled := false + next := func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + nextCalled = true + return nil, nil + } + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{}, + } + + _, err := d.handle(context.Background(), "custom/method", req, next) + require.NoError(t, err) + assert.True(t, nextCalled, "unknown methods should fall through to next") +} + +func TestHandle_RoutesToCorrectHandler(t *testing.T) { + const variantID = "v1" + + tests := []struct { + name string + method string + req mcp.Request + }{ + {"list routes to handleList", "tools/list", &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{Meta: mcp.Meta{metaKeyVariant: variantID}}, + }}, + {"call routes to handleDirect", "tools/call", &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{Meta: mcp.Meta{metaKeyVariant: variantID}, Name: "t"}, + }}, + {"subscribe routes to handleDirect", "resources/subscribe", &mcp.SubscribeRequest{ + Params: &mcp.SubscribeParams{Meta: mcp.Meta{metaKeyVariant: variantID}, URI: "u"}, + }}, + {"unsubscribe routes to handleDirect", "resources/unsubscribe", &mcp.UnsubscribeRequest{ + Params: &mcp.UnsubscribeParams{Meta: mcp.Meta{metaKeyVariant: variantID}, URI: "u"}, + }}, + {"complete routes to handleDirect", "completion/complete", &mcp.CompleteRequest{ + Params: &mcp.CompleteParams{Meta: mcp.Meta{metaKeyVariant: variantID}}, + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedMethod string + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedMethod = method + if method == "tools/list" { + return &mcp.ListToolsResult{}, nil + } + return nil, nil + }) + + next := func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatalf("next should not be called for method %s", method) + return nil, nil + } + + _, err := d.handle(context.Background(), tt.method, tt.req, next) + require.NoError(t, err) + assert.Equal(t, tt.method, receivedMethod) + }) + } +} + +// --------------------------------------------------------------------------- +// handleReceive (session.go) +// --------------------------------------------------------------------------- + +func TestHandleReceive_SessionInjection(t *testing.T) { + originalSession := &mcp.ServerSession{} + targetSession := &mcp.ServerSession{} + + var receivedSession *mcp.ServerSession + bs := &backendSession{ + variantID: "v1", + serverSession: targetSession, + mcpMethodHandler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedSession = req.GetSession().(*mcp.ServerSession) + return &mcp.ListToolsResult{}, nil + }, + } + + req := &mcp.ListToolsRequest{ + Session: originalSession, + Params: &mcp.ListToolsParams{}, + } + + _, err := bs.handleReceive(context.Background(), "tools/list", req) + require.NoError(t, err) + assert.Same(t, targetSession, receivedSession, "handleReceive should replace Session with inner server session") + assert.Same(t, originalSession, req.Session, "handleReceive should not mutate the original request") +} + +// --------------------------------------------------------------------------- +// getConnection +// --------------------------------------------------------------------------- + +func TestGetConnection_InvalidVariant(t *testing.T) { + d := newTestDispatcher("v1", nil) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: "nonexistent"}, + }, + } + + _, err := d.getConnection(context.Background(), req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) + assert.Contains(t, jErr.Message, "Invalid server variant") + assert.Contains(t, string(jErr.Data), "nonexistent") +} + +func TestGetConnection_DefaultVariant(t *testing.T) { + d := newTestDispatcher("v1", nil) + + // No variant in meta — should fall back to default (first-ranked). + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{}, + } + + conn, err := d.getConnection(context.Background(), req) + require.NoError(t, err) + assert.Equal(t, "v1", conn.backendSession.variantID) +} + +// --------------------------------------------------------------------------- +// enrichError +// --------------------------------------------------------------------------- + +func TestEnrichError(t *testing.T) { + tests := []struct { + name string + err error + wantEnrich bool + }{ + { + name: "InvalidParams enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "bad param", + }, + wantEnrich: true, + }, + { + name: "MethodNotFound enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: "no method", + }, + wantEnrich: true, + }, + { + name: "InternalError not enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server error", + }, + wantEnrich: false, + }, + { + name: "plain error not enriched", + err: errors.New("some error"), + wantEnrich: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := enrichError(tt.err, "v1") + + var jErr *jsonrpc.Error + if errors.As(result, &jErr) && tt.wantEnrich { + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), `"v1"`) + } else if !tt.wantEnrich { + if errors.As(result, &jErr) { + // Should NOT have activeVariant + assert.NotContains(t, string(jErr.Data), `"activeVariant"`) + } + } + }) + } +} diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index add4563..d2f0449 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -6,7 +6,8 @@ package variants import ( "context" - "fmt" + "errors" + "reflect" "sync" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -42,110 +43,34 @@ type backendSession struct { mcpMethodHandler mcp.MethodHandler } -func (s *backendSession) ListTools(ctx context.Context, p *mcp.ListToolsParams, extra *mcp.RequestExtra) (*mcp.ListToolsResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/list", &mcp.ListToolsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListToolsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/list", result) - } - return r, nil -} - -func (s *backendSession) ListResources(ctx context.Context, p *mcp.ListResourcesParams, extra *mcp.RequestExtra) (*mcp.ListResourcesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/list", &mcp.ListResourcesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourcesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/list", result) - } - return r, nil -} - -func (s *backendSession) ListPrompts(ctx context.Context, p *mcp.ListPromptsParams, extra *mcp.RequestExtra) (*mcp.ListPromptsResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/list", &mcp.ListPromptsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListPromptsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/list", result) - } - return r, nil -} - -func (s *backendSession) ListResourceTemplates(ctx context.Context, p *mcp.ListResourceTemplatesParams, extra *mcp.RequestExtra) (*mcp.ListResourceTemplatesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/templates/list", &mcp.ListResourceTemplatesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourceTemplatesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/templates/list", result) - } - return r, nil -} - -func (s *backendSession) CallTool(ctx context.Context, p *mcp.CallToolParamsRaw, extra *mcp.RequestExtra) (*mcp.CallToolResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/call", &mcp.CallToolRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CallToolResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/call", result) - } - return r, nil -} - -func (s *backendSession) ReadResource(ctx context.Context, p *mcp.ReadResourceParams, extra *mcp.RequestExtra) (*mcp.ReadResourceResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/read", &mcp.ReadResourceRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ReadResourceResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/read", result) - } - return r, nil -} - -func (s *backendSession) GetPrompt(ctx context.Context, p *mcp.GetPromptParams, extra *mcp.RequestExtra) (*mcp.GetPromptResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/get", &mcp.GetPromptRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.GetPromptResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/get", result) - } - return r, nil -} - -func (s *backendSession) Subscribe(ctx context.Context, p *mcp.SubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/subscribe", &mcp.SubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Unsubscribe(ctx context.Context, p *mcp.UnsubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/unsubscribe", &mcp.UnsubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Complete(ctx context.Context, p *mcp.CompleteParams, extra *mcp.RequestExtra) (*mcp.CompleteResult, error) { - result, err := s.mcpMethodHandler(ctx, "completion/complete", &mcp.CompleteRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CompleteResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for completion/complete", result) - } - return r, nil +// handleReceive invokes mcpMethodHandler for any MCP method by modifying the request's +// Session field to point to the inner server session. This replaces the explicit +// per-method functions (ListTools, CallTool, etc.) with a single generic handler. +// +// The dispatcher is responsible for modifying params (metadata injection, +// cursor unwrapping) before calling this method. +func (s *backendSession) handleReceive(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // Shallow-copy the concrete request struct so we don't mutate the caller's object + // while replacing the Session field with our inner server session. + // We can't use a wrapper (like the sending side's sessionSwappedRequest) because + // the SDK's receiving handler does type assertions on the concrete request type. + reqVal := reflect.ValueOf(req) + if reqVal.Kind() != reflect.Ptr { + return nil, errors.New("variants: expected pointer to request struct") + } + + // Allocate a new struct of the same concrete type and shallow-copy all fields. + copyPtr := reflect.New(reqVal.Elem().Type()) + copyPtr.Elem().Set(reqVal.Elem()) + + // Set Session on the copy to point to the inner server session. + sessionField := copyPtr.Elem().FieldByName("Session") + if !sessionField.IsValid() || !sessionField.CanSet() { + return nil, errors.New("variants: request type missing settable Session field") + } + sessionField.Set(reflect.ValueOf(s.serverSession)) + + return s.mcpMethodHandler(ctx, method, copyPtr.Interface().(mcp.Request)) } // sessionState holds all per-session state for one front client.