From fb633d4e4fda462d9b1394537f250f623513183c Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Mon, 23 Mar 2026 12:00:12 +0000 Subject: [PATCH 1/7] Refactor receiving method redirection to reduce duplications. --- go/sdk/variants/dispatch.go | 261 ++++++++---------------------------- go/sdk/variants/session.go | 119 +++------------- 2 files changed, 73 insertions(+), 307 deletions(-) diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index ef1440e..3bc896f 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -29,16 +29,8 @@ func (d *dispatcher) handle(ctx context.Context, method string, req mcp.Request, switch method { case "tools/list", "resources/list", "prompts/list", "resources/templates/list": return d.handleList(ctx, method, req) - case "tools/call", "resources/read", "prompts/get": - return d.handleCall(ctx, method, req) - case "resources/subscribe": - return d.handleSubscribe(ctx, req) - case "resources/unsubscribe": - return d.handleUnsubscribe(ctx, req) - case "completion/complete": - return d.handleCompletion(ctx, req) default: - return next(ctx, method, req) + return d.handleReceiveRedirect(ctx, method, req) } } @@ -66,16 +58,23 @@ func (d *dispatcher) createInvalidVariantError(ctx context.Context, requestedVar } } +// isParamsNil checks if params is nil or a typed-nil (a nil pointer wrapped in an interface). +// The SDK can produce typed-nil params for requests with no parameters. +func isParamsNil(params mcp.Params) bool { + if params == nil { + return true + } + v := reflect.ValueOf(params) + return v.Kind() == reflect.Ptr && v.IsNil() +} + // variantIDFromMeta extracts the variant ID from the request's _meta field. // Returns empty string if no variant is specified. Guards against typed-nil // params (e.g. (*ListToolsParams)(nil) wrapped in the mcp.Params interface) // which the SDK can produce for requests with no parameters. func variantIDFromMeta(req mcp.Request) string { params := req.GetParams() - if params == nil { - return "" - } - if v := reflect.ValueOf(params); v.Kind() == reflect.Ptr && v.IsNil() { + if isParamsNil(params) { return "" } meta := params.GetMeta() @@ -160,7 +159,20 @@ func enrichError(err error, variantID string) error { // List methods // --------------------------------------------------------------------------- -// handleList handles list methods by forwarding to the appropriate variant. +// cursorParams is implemented by all list params types that support pagination. +type cursorParams interface { + mcp.Params + GetCursor() string + SetCursor(string) +} + +// cursorResult is implemented by all list result types that support pagination. +type cursorResult interface { + GetNextCursor() string + SetNextCursor(string) +} + +// handleList handles list methods using the generic backend session call method. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) @@ -171,230 +183,67 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ backendSession := conn.backendSession variantID := backendSession.variantID params := req.GetParams() - extra := req.GetExtra() - switch method { - case "tools/list": - p, _ := params.(*mcp.ListToolsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListTools(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/list": - p, _ := params.(*mcp.ListResourcesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListResources(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "prompts/list": - p, _ := params.(*mcp.ListPromptsParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - } - result, err := backendSession.ListPrompts(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - case "resources/templates/list": - p, _ := params.(*mcp.ListResourceTemplatesParams) - if p != nil { - injectVariantMeta(p, variantID) - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) + // Inject variant metadata and handle cursor unwrapping (guard against typed-nil params) + if !isParamsNil(params) { + injectVariantMeta(params, variantID) + + // Handle cursor unwrapping if params support it + if cp, ok := params.(cursorParams); ok { + if cursor := cp.GetCursor(); cursor != "" { + innerCursor, err := unwrapCursor(cursor, variantID) if err != nil { return nil, err } - p.Cursor = innerCursor + cp.SetCursor(innerCursor) } } - result, err := backendSession.ListResourceTemplates(ctx, p, extra) - if err != nil { - return nil, enrichError(err, variantID) - } - if result != nil && result.NextCursor != "" { - result.NextCursor = wrapCursor(result.NextCursor, variantID) - } - return result, nil - - default: - return nil, errors.New("unsupported list method: " + method) } -} -// --------------------------------------------------------------------------- -// Call methods -// --------------------------------------------------------------------------- - -// handleCall handles call methods (tools/call, resources/read, prompts/get). -func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) + // Generic dispatch - pass entire request object + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { - return nil, err + return nil, enrichError(err, variantID) } - backendSession := conn.backendSession - variantID := backendSession.variantID - params := req.GetParams() - extra := req.GetExtra() - var result mcp.Result - - switch method { - case "tools/call": - raw, _ := params.(*mcp.CallToolParamsRaw) - if raw == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid tools/call params", - } - } - injectVariantMeta(raw, variantID) - result, err = backendSession.CallTool(ctx, raw, extra) - case "resources/read": - p, _ := params.(*mcp.ReadResourceParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/read params", - } - } - injectVariantMeta(p, variantID) - result, err = backendSession.ReadResource(ctx, p, extra) - case "prompts/get": - p, _ := params.(*mcp.GetPromptParams) - if p == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid prompts/get params", - } + // Handle cursor wrapping in result + if cr, ok := result.(cursorResult); ok { + if cursor := cr.GetNextCursor(); cursor != "" { + cr.SetNextCursor(wrapCursor(cursor, variantID)) } - injectVariantMeta(p, variantID) - result, err = backendSession.GetPrompt(ctx, p, extra) - default: - return nil, errors.New("unsupported call method: " + method) } - if err != nil { - return nil, enrichError(err, variantID) - } return result, nil } // --------------------------------------------------------------------------- -// Subscription methods +// Simple methods (no pagination) // --------------------------------------------------------------------------- -// handleSubscribe handles resources/subscribe. -func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { +// handleReceiveRedirect handles all simple methods (call, subscribe, unsubscribe, completion) +// that don't require special cursor handling. This consolidates what were previously +// separate handlers for handleCall, handleSubscribe, handleUnsubscribe, and handleCompletion. +func (d *dispatcher) handleReceiveRedirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.SubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/subscribe params", - } - } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Subscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) - } - return nil, nil -} - -// handleUnsubscribe handles resources/unsubscribe. -// Per SEP-2053: "Servers MUST continue to accept resources/unsubscribe for -// existing subscription ids even if the underlying resource is no longer available." -func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) - if err != nil { - return nil, err - } + variantID := backendSession.variantID + params := req.GetParams() - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.UnsubscribeParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid resources/unsubscribe params", - } - } - injectVariantMeta(params, backendSession.variantID) - if err := backendSession.Unsubscribe(ctx, params, req.GetExtra()); err != nil { - return nil, enrichError(err, backendSession.variantID) + // Inject variant metadata (guard against typed-nil params) + if !isParamsNil(params) { + injectVariantMeta(params, variantID) } - return nil, nil -} - -// --------------------------------------------------------------------------- -// Completion -// --------------------------------------------------------------------------- -// handleCompletion handles completion/complete. -func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp.Result, error) { - conn, err := d.getConnection(ctx, req) + // Generic dispatch - pass entire request object + result, err := backendSession.handleReceive(ctx, method, req) if err != nil { - return nil, err + return nil, enrichError(err, variantID) } - backendSession := conn.backendSession - params, _ := req.GetParams().(*mcp.CompleteParams) - if params == nil { - return nil, &jsonrpc.Error{ - Code: jsonrpc.CodeInvalidParams, - Message: "missing or invalid completion/complete params", - } - } - injectVariantMeta(params, backendSession.variantID) - result, err := backendSession.Complete(ctx, params, req.GetExtra()) - if err != nil { - return nil, enrichError(err, backendSession.variantID) - } return result, nil } diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index add4563..3e57368 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -6,7 +6,7 @@ package variants import ( "context" - "fmt" + "reflect" "sync" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -42,110 +42,27 @@ type backendSession struct { mcpMethodHandler mcp.MethodHandler } -func (s *backendSession) ListTools(ctx context.Context, p *mcp.ListToolsParams, extra *mcp.RequestExtra) (*mcp.ListToolsResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/list", &mcp.ListToolsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListToolsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/list", result) - } - return r, nil -} - -func (s *backendSession) ListResources(ctx context.Context, p *mcp.ListResourcesParams, extra *mcp.RequestExtra) (*mcp.ListResourcesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/list", &mcp.ListResourcesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourcesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/list", result) - } - return r, nil -} - -func (s *backendSession) ListPrompts(ctx context.Context, p *mcp.ListPromptsParams, extra *mcp.RequestExtra) (*mcp.ListPromptsResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/list", &mcp.ListPromptsRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListPromptsResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/list", result) - } - return r, nil -} - -func (s *backendSession) ListResourceTemplates(ctx context.Context, p *mcp.ListResourceTemplatesParams, extra *mcp.RequestExtra) (*mcp.ListResourceTemplatesResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/templates/list", &mcp.ListResourceTemplatesRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ListResourceTemplatesResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/templates/list", result) - } - return r, nil -} - -func (s *backendSession) CallTool(ctx context.Context, p *mcp.CallToolParamsRaw, extra *mcp.RequestExtra) (*mcp.CallToolResult, error) { - result, err := s.mcpMethodHandler(ctx, "tools/call", &mcp.CallToolRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CallToolResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for tools/call", result) +// handleReceive invokes mcpMethodHandler for any MCP method by modifying the request's +// Session field to point to the inner server session. This replaces the explicit +// per-method functions (ListTools, CallTool, etc.) with a single generic handler. +// +// The dispatcher is responsible for modifying params (metadata injection, +// cursor unwrapping) before calling this method. +func (s *backendSession) handleReceive(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // Use reflection to modify the Session field in place. We can't wrap the + // request because the SDK's receiving handler does type assertions on the + // concrete request type (e.g., *mcp.ServerRequest[*mcp.CallToolParamsRaw]). + reqVal := reflect.ValueOf(req) + if reqVal.Kind() == reflect.Ptr { + reqVal = reqVal.Elem() } - return r, nil -} -func (s *backendSession) ReadResource(ctx context.Context, p *mcp.ReadResourceParams, extra *mcp.RequestExtra) (*mcp.ReadResourceResult, error) { - result, err := s.mcpMethodHandler(ctx, "resources/read", &mcp.ReadResourceRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.ReadResourceResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for resources/read", result) + sessionField := reqVal.FieldByName("Session") + if sessionField.IsValid() && sessionField.CanSet() { + sessionField.Set(reflect.ValueOf(s.serverSession)) } - return r, nil -} -func (s *backendSession) GetPrompt(ctx context.Context, p *mcp.GetPromptParams, extra *mcp.RequestExtra) (*mcp.GetPromptResult, error) { - result, err := s.mcpMethodHandler(ctx, "prompts/get", &mcp.GetPromptRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.GetPromptResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for prompts/get", result) - } - return r, nil -} - -func (s *backendSession) Subscribe(ctx context.Context, p *mcp.SubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/subscribe", &mcp.SubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Unsubscribe(ctx context.Context, p *mcp.UnsubscribeParams, extra *mcp.RequestExtra) error { - _, err := s.mcpMethodHandler(ctx, "resources/unsubscribe", &mcp.UnsubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) - return err -} - -func (s *backendSession) Complete(ctx context.Context, p *mcp.CompleteParams, extra *mcp.RequestExtra) (*mcp.CompleteResult, error) { - result, err := s.mcpMethodHandler(ctx, "completion/complete", &mcp.CompleteRequest{Session: s.serverSession, Params: p, Extra: extra}) - if err != nil { - return nil, err - } - r, ok := result.(*mcp.CompleteResult) - if !ok && result != nil { - return nil, fmt.Errorf("unexpected result type %T for completion/complete", result) - } - return r, nil + return s.mcpMethodHandler(ctx, method, req) } // sessionState holds all per-session state for one front client. From d42e4842050e84b0da50715aced5eebcc0ee46c8 Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Mon, 23 Mar 2026 12:59:08 +0000 Subject: [PATCH 2/7] Fix cursor setting for handle list --- go/sdk/variants/dispatch.go | 70 +++++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index 3bc896f..a2df34b 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -159,21 +159,9 @@ func enrichError(err error, variantID string) error { // List methods // --------------------------------------------------------------------------- -// cursorParams is implemented by all list params types that support pagination. -type cursorParams interface { - mcp.Params - GetCursor() string - SetCursor(string) -} - -// cursorResult is implemented by all list result types that support pagination. -type cursorResult interface { - GetNextCursor() string - SetNextCursor(string) -} - // handleList handles list methods using the generic backend session call method. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. +// Uses type switches for type-safe cursor field access. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { @@ -188,14 +176,39 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ if !isParamsNil(params) { injectVariantMeta(params, variantID) - // Handle cursor unwrapping if params support it - if cp, ok := params.(cursorParams); ok { - if cursor := cp.GetCursor(); cursor != "" { - innerCursor, err := unwrapCursor(cursor, variantID) + // Handle cursor unwrapping using type switch for type-safe field access + switch p := params.(type) { + case *mcp.ListToolsParams: + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor + } + case *mcp.ListResourcesParams: + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) if err != nil { return nil, err } - cp.SetCursor(innerCursor) + p.Cursor = innerCursor + } + case *mcp.ListPromptsParams: + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor + } + case *mcp.ListResourceTemplatesParams: + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor } } } @@ -206,10 +219,23 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ return nil, enrichError(err, variantID) } - // Handle cursor wrapping in result - if cr, ok := result.(cursorResult); ok { - if cursor := cr.GetNextCursor(); cursor != "" { - cr.SetNextCursor(wrapCursor(cursor, variantID)) + // Handle cursor wrapping in result using type switch for type-safe field access + switch r := result.(type) { + case *mcp.ListToolsResult: + if r.NextCursor != "" { + r.NextCursor = wrapCursor(r.NextCursor, variantID) + } + case *mcp.ListResourcesResult: + if r.NextCursor != "" { + r.NextCursor = wrapCursor(r.NextCursor, variantID) + } + case *mcp.ListPromptsResult: + if r.NextCursor != "" { + r.NextCursor = wrapCursor(r.NextCursor, variantID) + } + case *mcp.ListResourceTemplatesResult: + if r.NextCursor != "" { + r.NextCursor = wrapCursor(r.NextCursor, variantID) } } From 815ac172b84a3cce61599a184e8bb70ecfff4dbd Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Tue, 24 Mar 2026 09:42:41 +0000 Subject: [PATCH 3/7] Use reflect to acsess Cursor and NextCursor fields --- go/sdk/variants/dispatch.go | 59 +++++-------------------------------- 1 file changed, 7 insertions(+), 52 deletions(-) diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index a2df34b..8145246 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -161,7 +161,6 @@ func enrichError(err error, variantID string) error { // handleList handles list methods using the generic backend session call method. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. -// Uses type switches for type-safe cursor field access. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { @@ -176,40 +175,12 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ if !isParamsNil(params) { injectVariantMeta(params, variantID) - // Handle cursor unwrapping using type switch for type-safe field access - switch p := params.(type) { - case *mcp.ListToolsParams: - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - case *mcp.ListResourcesParams: - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - case *mcp.ListPromptsParams: - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor - } - case *mcp.ListResourceTemplatesParams: - if p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err - } - p.Cursor = innerCursor + if f := reflect.ValueOf(params).Elem().FieldByName("Cursor"); f.IsValid() && f.String() != "" { + innerCursor, err := unwrapCursor(f.String(), variantID) + if err != nil { + return nil, err } + f.SetString(innerCursor) } } @@ -219,24 +190,8 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ return nil, enrichError(err, variantID) } - // Handle cursor wrapping in result using type switch for type-safe field access - switch r := result.(type) { - case *mcp.ListToolsResult: - if r.NextCursor != "" { - r.NextCursor = wrapCursor(r.NextCursor, variantID) - } - case *mcp.ListResourcesResult: - if r.NextCursor != "" { - r.NextCursor = wrapCursor(r.NextCursor, variantID) - } - case *mcp.ListPromptsResult: - if r.NextCursor != "" { - r.NextCursor = wrapCursor(r.NextCursor, variantID) - } - case *mcp.ListResourceTemplatesResult: - if r.NextCursor != "" { - r.NextCursor = wrapCursor(r.NextCursor, variantID) - } + if f := reflect.ValueOf(result).Elem().FieldByName("NextCursor"); f.IsValid() && f.String() != "" { + f.SetString(wrapCursor(f.String(), variantID)) } return result, nil From 4fc28738faa2ae66267606aaff5709e29018a7dc Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Fri, 24 Apr 2026 22:55:57 +0100 Subject: [PATCH 4/7] guard nil result in handleList --- go/sdk/variants/dispatch.go | 20 +++++----- go/sdk/variants/dispatch_test.go | 65 ++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 9 deletions(-) create mode 100644 go/sdk/variants/dispatch_test.go diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index 8145246..f3d80d0 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -58,14 +58,13 @@ func (d *dispatcher) createInvalidVariantError(ctx context.Context, requestedVar } } -// isParamsNil checks if params is nil or a typed-nil (a nil pointer wrapped in an interface). -// The SDK can produce typed-nil params for requests with no parameters. -func isParamsNil(params mcp.Params) bool { - if params == nil { +// isNilInterface checks if v is nil or a typed-nil (a nil pointer wrapped in an interface). +func isNilInterface(v any) bool { + if v == nil { return true } - v := reflect.ValueOf(params) - return v.Kind() == reflect.Ptr && v.IsNil() + rv := reflect.ValueOf(v) + return rv.Kind() == reflect.Ptr && rv.IsNil() } // variantIDFromMeta extracts the variant ID from the request's _meta field. @@ -74,7 +73,7 @@ func isParamsNil(params mcp.Params) bool { // which the SDK can produce for requests with no parameters. func variantIDFromMeta(req mcp.Request) string { params := req.GetParams() - if isParamsNil(params) { + if isNilInterface(params) { return "" } meta := params.GetMeta() @@ -172,7 +171,7 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ params := req.GetParams() // Inject variant metadata and handle cursor unwrapping (guard against typed-nil params) - if !isParamsNil(params) { + if !isNilInterface(params) { injectVariantMeta(params, variantID) if f := reflect.ValueOf(params).Elem().FieldByName("Cursor"); f.IsValid() && f.String() != "" { @@ -189,6 +188,9 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ if err != nil { return nil, enrichError(err, variantID) } + if isNilInterface(result) { + return nil, nil + } if f := reflect.ValueOf(result).Elem().FieldByName("NextCursor"); f.IsValid() && f.String() != "" { f.SetString(wrapCursor(f.String(), variantID)) @@ -215,7 +217,7 @@ func (d *dispatcher) handleReceiveRedirect(ctx context.Context, method string, r params := req.GetParams() // Inject variant metadata (guard against typed-nil params) - if !isParamsNil(params) { + if !isNilInterface(params) { injectVariantMeta(params, variantID) } diff --git a/go/sdk/variants/dispatch_test.go b/go/sdk/variants/dispatch_test.go new file mode 100644 index 0000000..a3813d5 --- /dev/null +++ b/go/sdk/variants/dispatch_test.go @@ -0,0 +1,65 @@ +// Copyright 2025 The MCP Variants Authors. All rights reserved. +// Use of this source code is governed by a Apache-2.0 +// license that can be found in the LICENSE file. + +package variants + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestHandleList_NilResult(t *testing.T) { + const variantID = "test-variant" + + tests := []struct { + name string + handler func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) + }{ + { + name: "untyped nil", + handler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, nil + }, + }, + { + name: "typed nil", + handler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return (*mcp.ListToolsResult)(nil), nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dispatcher{ + server: NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}). + WithVariant(ServerVariant{ID: variantID, Status: Stable}, mcp.NewServer(&mcp.Implementation{Name: "inner", Version: "v0.0.1"}, nil), 0), + connections: map[string]*innerConnection{ + variantID: { + backendSession: &backendSession{ + variantID: variantID, + mcpMethodHandler: tt.handler, + }, + }, + }, + } + + req := &mcp.ServerRequest[*mcp.ListToolsParams]{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + result, err := d.handleList(context.Background(), "tools/list", req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatalf("expected nil result, got %v", result) + } + }) + } +} From dc7f6fd7db2bff7cc581cfc65ce1eb0922dd80fd Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Tue, 28 Apr 2026 09:14:04 +0100 Subject: [PATCH 5/7] Dispatch will pass unknown jsonrpc methods through to next handler Fix nil handling for params and result Add unit tests for dispatcher --- go/sdk/variants/dispatch.go | 45 ++- go/sdk/variants/dispatch_test.go | 656 ++++++++++++++++++++++++++++++- go/sdk/variants/session.go | 6 +- 3 files changed, 684 insertions(+), 23 deletions(-) diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index f3d80d0..a743c29 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -29,8 +29,12 @@ func (d *dispatcher) handle(ctx context.Context, method string, req mcp.Request, switch method { case "tools/list", "resources/list", "prompts/list", "resources/templates/list": return d.handleList(ctx, method, req) + case "tools/call", "resources/read", "prompts/get", + "resources/subscribe", "resources/unsubscribe", + "completion/complete": + return d.handleDirect(ctx, method, req) default: - return d.handleReceiveRedirect(ctx, method, req) + return next(ctx, method, req) } } @@ -67,6 +71,32 @@ func isNilInterface(v any) bool { return rv.Kind() == reflect.Ptr && rv.IsNil() } +// The reflect-based field access (cursor unwrap/wrap, metadata injection) calls +// Elem() to dereference pointers, which panics on non-pointer types. Even without +// the panic, mutations on a value type would not propagate back to the original +// request/result. The current SDK satisfies this: all Params and Result types use +// pointer receivers for their interface marker methods (isParams/isResult), so only +// pointer types can implement the interfaces. These checks guard against future +// SDK changes. +var ( + errParamsNotPointer = errors.New("variants: expected pointer type for Params, got value type") + errResultNotPointer = errors.New("variants: expected pointer type for Result, got value type") +) + +// Compile-time assertions: ensure list param/result types have the Cursor and +// NextCursor fields that handleList accesses via reflection. +var ( + _ = mcp.ListToolsParams{}.Cursor + _ = mcp.ListResourcesParams{}.Cursor + _ = mcp.ListPromptsParams{}.Cursor + _ = mcp.ListResourceTemplatesParams{}.Cursor + + _ = mcp.ListToolsResult{}.NextCursor + _ = mcp.ListResourcesResult{}.NextCursor + _ = mcp.ListPromptsResult{}.NextCursor + _ = mcp.ListResourceTemplatesResult{}.NextCursor +) + // variantIDFromMeta extracts the variant ID from the request's _meta field. // Returns empty string if no variant is specified. Guards against typed-nil // params (e.g. (*ListToolsParams)(nil) wrapped in the mcp.Params interface) @@ -172,6 +202,9 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ // Inject variant metadata and handle cursor unwrapping (guard against typed-nil params) if !isNilInterface(params) { + if reflect.ValueOf(params).Kind() != reflect.Ptr { + return nil, errParamsNotPointer + } injectVariantMeta(params, variantID) if f := reflect.ValueOf(params).Elem().FieldByName("Cursor"); f.IsValid() && f.String() != "" { @@ -192,6 +225,9 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ return nil, nil } + if reflect.ValueOf(result).Kind() != reflect.Ptr { + return nil, errResultNotPointer + } if f := reflect.ValueOf(result).Elem().FieldByName("NextCursor"); f.IsValid() && f.String() != "" { f.SetString(wrapCursor(f.String(), variantID)) } @@ -203,10 +239,10 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ // Simple methods (no pagination) // --------------------------------------------------------------------------- -// handleReceiveRedirect handles all simple methods (call, subscribe, unsubscribe, completion) +// handleDirect handles all simple methods (call, subscribe, unsubscribe, completion) // that don't require special cursor handling. This consolidates what were previously // separate handlers for handleCall, handleSubscribe, handleUnsubscribe, and handleCompletion. -func (d *dispatcher) handleReceiveRedirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { +func (d *dispatcher) handleDirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { return nil, err @@ -218,6 +254,9 @@ func (d *dispatcher) handleReceiveRedirect(ctx context.Context, method string, r // Inject variant metadata (guard against typed-nil params) if !isNilInterface(params) { + if reflect.ValueOf(params).Kind() != reflect.Ptr { + return nil, errParamsNotPointer + } injectVariantMeta(params, variantID) } diff --git a/go/sdk/variants/dispatch_test.go b/go/sdk/variants/dispatch_test.go index a3813d5..15863f1 100644 --- a/go/sdk/variants/dispatch_test.go +++ b/go/sdk/variants/dispatch_test.go @@ -6,11 +6,61 @@ package variants import ( "context" + "errors" + "reflect" "testing" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// --------------------------------------------------------------------------- +// isNilInterface +// --------------------------------------------------------------------------- + +func TestIsNilInterface(t *testing.T) { + tests := []struct { + name string + val any + want bool + }{ + {"untyped nil", nil, true}, + {"typed nil pointer", (*mcp.ListToolsParams)(nil), true}, + {"non-nil pointer", &mcp.ListToolsParams{}, false}, + {"non-pointer value", 42, false}, + {"string value", "hello", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isNilInterface(tt.val)) + }) + } +} + +// --------------------------------------------------------------------------- +// handleList +// --------------------------------------------------------------------------- + +// newTestDispatcher builds a dispatcher with a single variant backed by a +// custom mcpMethodHandler. The handler receives the request exactly as the +// dispatcher sends it, so tests can inspect params/session mutations. +func newTestDispatcher(variantID string, handler mcp.MethodHandler) *dispatcher { + return &dispatcher{ + server: NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}). + WithVariant(ServerVariant{ID: variantID, Status: Stable}, mcp.NewServer(&mcp.Implementation{Name: "inner", Version: "v0.0.1"}, nil), 0), + connections: map[string]*innerConnection{ + variantID: { + backendSession: &backendSession{ + variantID: variantID, + mcpMethodHandler: handler, + }, + }, + }, + } +} + func TestHandleList_NilResult(t *testing.T) { const variantID = "test-variant" @@ -34,31 +84,601 @@ func TestHandleList_NilResult(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - d := &dispatcher{ - server: NewServer(&mcp.Implementation{Name: "test", Version: "v0.0.1"}). - WithVariant(ServerVariant{ID: variantID, Status: Stable}, mcp.NewServer(&mcp.Implementation{Name: "inner", Version: "v0.0.1"}, nil), 0), - connections: map[string]*innerConnection{ - variantID: { - backendSession: &backendSession{ - variantID: variantID, - mcpMethodHandler: tt.handler, - }, - }, - }, - } - - req := &mcp.ServerRequest[*mcp.ListToolsParams]{ + d := newTestDispatcher(variantID, tt.handler) + req := &mcp.ListToolsRequest{ Params: &mcp.ListToolsParams{ Meta: mcp.Meta{metaKeyVariant: variantID}, }, } result, err := d.handleList(context.Background(), "tools/list", req) - if err != nil { - t.Fatalf("unexpected error: %v", err) + require.NoError(t, err) + assert.Nil(t, result) + }) + } +} + +func TestHandleList_CursorRoundTrip(t *testing.T) { + const variantID = "v1" + const innerCursor = "page-2-token" + + tests := []struct { + name string + method string + req func(wrappedCursor string) mcp.Request + result func() mcp.Result // result returned by backend with NextCursor set + }{ + { + name: "tools/list", + method: "tools/list", + req: func(c string) mcp.Request { + return &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListToolsResult{NextCursor: "next-inner"} + }, + }, + { + name: "resources/list", + method: "resources/list", + req: func(c string) mcp.Request { + return &mcp.ListResourcesRequest{ + Params: &mcp.ListResourcesParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListResourcesResult{NextCursor: "next-inner"} + }, + }, + { + name: "prompts/list", + method: "prompts/list", + req: func(c string) mcp.Request { + return &mcp.ListPromptsRequest{ + Params: &mcp.ListPromptsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListPromptsResult{NextCursor: "next-inner"} + }, + }, + { + name: "resources/templates/list", + method: "resources/templates/list", + req: func(c string) mcp.Request { + return &mcp.ListResourceTemplatesRequest{ + Params: &mcp.ListResourceTemplatesParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: c, + }, + } + }, + result: func() mcp.Result { + return &mcp.ListResourceTemplatesResult{NextCursor: "next-inner"} + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wrappedCursor := wrapCursor(innerCursor, variantID) + var receivedCursor string + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + // Capture what cursor the backend actually receives. + params := req.GetParams() + f := reflect.ValueOf(params).Elem().FieldByName("Cursor") + receivedCursor = f.String() + return tt.result(), nil + }) + + result, err := d.handleList(context.Background(), tt.method, tt.req(wrappedCursor)) + require.NoError(t, err) + + // Backend should have received the inner (unwrapped) cursor. + assert.Equal(t, innerCursor, receivedCursor, "cursor should be unwrapped before reaching backend") + + // Result's NextCursor should now be wrapped. + nextCursor := reflect.ValueOf(result).Elem().FieldByName("NextCursor").String() + unwrapped, err := unwrapCursor(nextCursor, variantID) + require.NoError(t, err) + assert.Equal(t, "next-inner", unwrapped, "NextCursor should be wrapped with variant ID") + }) + } +} + +func TestHandleList_NoCursor(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return &mcp.ListToolsResult{}, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + result, err := d.handleList(context.Background(), "tools/list", req) + require.NoError(t, err) + + // No NextCursor should be set when backend returns empty cursor. + r := result.(*mcp.ListToolsResult) + assert.Empty(t, r.NextCursor) +} + +func TestHandleList_InvalidCursor(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("backend should not be called with invalid cursor") + return nil, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: "not-valid-base64!@#$", + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) +} + +func TestHandleList_CrossVariantCursor(t *testing.T) { + const variantID = "v1" + otherVariantCursor := wrapCursor("page2", "other-variant") + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("backend should not be called with cross-variant cursor") + return nil, nil + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Cursor: otherVariantCursor, + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) + assert.Contains(t, jErr.Message, "Cursor invalid for requested variant") +} + +func TestHandleList_ErrorEnrichment(t *testing.T) { + const variantID = "v1" + + backendErr := &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "unknown tool", + } + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, backendErr + }) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), variantID) +} + +func TestHandleList_NilParams(t *testing.T) { + const variantID = "v1" + called := false + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + called = true + return &mcp.ListToolsResult{}, nil + }) + + // Typed-nil params — the SDK produces these for parameterless requests. + req := &mcp.ListToolsRequest{ + Params: (*mcp.ListToolsParams)(nil), + } + + _, err := d.handleList(context.Background(), "tools/list", req) + require.NoError(t, err) + assert.True(t, called, "backend should still be called with nil params") +} + +// --------------------------------------------------------------------------- +// handleDirect +// --------------------------------------------------------------------------- + +func TestHandleDirect_MethodRouting(t *testing.T) { + const variantID = "v1" + + tests := []struct { + name string + method string + req mcp.Request + }{ + { + name: "tools/call", + method: "tools/call", + req: &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-tool", + }, + }, + }, + { + name: "resources/read", + method: "resources/read", + req: &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///test", + }, + }, + }, + { + name: "prompts/get", + method: "prompts/get", + req: &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-prompt", + }, + }, + }, + { + name: "resources/subscribe", + method: "resources/subscribe", + req: &mcp.SubscribeRequest{ + Params: &mcp.SubscribeParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///watch", + }, + }, + }, + { + name: "resources/unsubscribe", + method: "resources/unsubscribe", + req: &mcp.UnsubscribeRequest{ + Params: &mcp.UnsubscribeParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + URI: "file:///watch", + }, + }, + }, + { + name: "completion/complete", + method: "completion/complete", + req: &mcp.CompleteRequest{ + Params: &mcp.CompleteParams{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedMethod string + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedMethod = method + return nil, nil + }) + + _, err := d.handleDirect(context.Background(), tt.method, tt.req) + require.NoError(t, err) + assert.Equal(t, tt.method, receivedMethod, "backend should receive the correct method") + }) + } +} + +func TestHandleDirect_MetaInjection(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + meta := req.GetParams().GetMeta() + if meta[metaKeyVariant] != variantID { + return nil, errors.New("variant meta not injected") + } + return &mcp.CallToolResult{}, nil + }) + + // Start without variant meta set — handleDirect should inject it. + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-tool", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + require.NoError(t, err) +} + +func TestHandleDirect_ErrorEnrichment(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "tool not found", + } + }) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "nonexistent", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), variantID) +} + +func TestHandleDirect_NonEnrichableError(t *testing.T) { + const variantID = "v1" + plainErr := errors.New("internal error") + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + return nil, plainErr + }) + + req := &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Meta: mcp.Meta{metaKeyVariant: variantID}, + Name: "my-tool", + }, + } + + _, err := d.handleDirect(context.Background(), "tools/call", req) + assert.Equal(t, plainErr, err, "non-jsonrpc errors should pass through unmodified") +} + +func TestHandleDirect_NilParams(t *testing.T) { + const variantID = "v1" + called := false + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + called = true + return nil, nil + }) + + req := &mcp.SubscribeRequest{ + Params: (*mcp.SubscribeParams)(nil), + } + + _, err := d.handleDirect(context.Background(), "resources/subscribe", req) + require.NoError(t, err) + assert.True(t, called, "backend should still be called when params are typed-nil") +} + +// --------------------------------------------------------------------------- +// handle (top-level router) +// --------------------------------------------------------------------------- + +func TestHandle_UnknownMethodPassthrough(t *testing.T) { + const variantID = "v1" + + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatal("dispatcher handler should not be called for unknown methods") + return nil, nil + }) + + nextCalled := false + next := func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + nextCalled = true + return nil, nil + } + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{}, + } + + _, err := d.handle(context.Background(), "custom/method", req, next) + require.NoError(t, err) + assert.True(t, nextCalled, "unknown methods should fall through to next") +} + +func TestHandle_RoutesToCorrectHandler(t *testing.T) { + const variantID = "v1" + + tests := []struct { + name string + method string + req mcp.Request + }{ + {"list routes to handleList", "tools/list", &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{Meta: mcp.Meta{metaKeyVariant: variantID}}, + }}, + {"call routes to handleDirect", "tools/call", &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{Meta: mcp.Meta{metaKeyVariant: variantID}, Name: "t"}, + }}, + {"subscribe routes to handleDirect", "resources/subscribe", &mcp.SubscribeRequest{ + Params: &mcp.SubscribeParams{Meta: mcp.Meta{metaKeyVariant: variantID}, URI: "u"}, + }}, + {"unsubscribe routes to handleDirect", "resources/unsubscribe", &mcp.UnsubscribeRequest{ + Params: &mcp.UnsubscribeParams{Meta: mcp.Meta{metaKeyVariant: variantID}, URI: "u"}, + }}, + {"complete routes to handleDirect", "completion/complete", &mcp.CompleteRequest{ + Params: &mcp.CompleteParams{Meta: mcp.Meta{metaKeyVariant: variantID}}, + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var receivedMethod string + d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedMethod = method + if method == "tools/list" { + return &mcp.ListToolsResult{}, nil + } + return nil, nil + }) + + next := func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + t.Fatalf("next should not be called for method %s", method) + return nil, nil } - if result != nil { - t.Fatalf("expected nil result, got %v", result) + + _, err := d.handle(context.Background(), tt.method, tt.req, next) + require.NoError(t, err) + assert.Equal(t, tt.method, receivedMethod) + }) + } +} + +// --------------------------------------------------------------------------- +// handleReceive (session.go) +// --------------------------------------------------------------------------- + +func TestHandleReceive_SessionInjection(t *testing.T) { + originalSession := &mcp.ServerSession{} + targetSession := &mcp.ServerSession{} + + var receivedSession *mcp.ServerSession + bs := &backendSession{ + variantID: "v1", + serverSession: targetSession, + mcpMethodHandler: func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + receivedSession = req.GetSession().(*mcp.ServerSession) + return &mcp.ListToolsResult{}, nil + }, + } + + req := &mcp.ListToolsRequest{ + Session: originalSession, + Params: &mcp.ListToolsParams{}, + } + + _, err := bs.handleReceive(context.Background(), "tools/list", req) + require.NoError(t, err) + assert.Same(t, targetSession, receivedSession, "handleReceive should replace Session with inner server session") +} + +// --------------------------------------------------------------------------- +// getConnection +// --------------------------------------------------------------------------- + +func TestGetConnection_InvalidVariant(t *testing.T) { + d := newTestDispatcher("v1", nil) + + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{ + Meta: mcp.Meta{metaKeyVariant: "nonexistent"}, + }, + } + + _, err := d.getConnection(context.Background(), req) + require.Error(t, err) + + var jErr *jsonrpc.Error + require.True(t, errors.As(err, &jErr)) + assert.EqualValues(t, jsonrpc.CodeInvalidParams, jErr.Code) + assert.Contains(t, jErr.Message, "Invalid server variant") + assert.Contains(t, string(jErr.Data), "nonexistent") +} + +func TestGetConnection_DefaultVariant(t *testing.T) { + d := newTestDispatcher("v1", nil) + + // No variant in meta — should fall back to default (first-ranked). + req := &mcp.ListToolsRequest{ + Params: &mcp.ListToolsParams{}, + } + + conn, err := d.getConnection(context.Background(), req) + require.NoError(t, err) + assert.Equal(t, "v1", conn.backendSession.variantID) +} + +// --------------------------------------------------------------------------- +// enrichError +// --------------------------------------------------------------------------- + +func TestEnrichError(t *testing.T) { + tests := []struct { + name string + err error + wantEnrich bool + }{ + { + name: "InvalidParams enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "bad param", + }, + wantEnrich: true, + }, + { + name: "MethodNotFound enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeMethodNotFound, + Message: "no method", + }, + wantEnrich: true, + }, + { + name: "InternalError not enriched", + err: &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server error", + }, + wantEnrich: false, + }, + { + name: "plain error not enriched", + err: errors.New("some error"), + wantEnrich: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := enrichError(tt.err, "v1") + + var jErr *jsonrpc.Error + if errors.As(result, &jErr) && tt.wantEnrich { + assert.Contains(t, string(jErr.Data), `"activeVariant"`) + assert.Contains(t, string(jErr.Data), `"v1"`) + } else if !tt.wantEnrich { + if errors.As(result, &jErr) { + // Should NOT have activeVariant + assert.NotContains(t, string(jErr.Data), `"activeVariant"`) + } } }) } diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index 3e57368..996f0c9 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -6,6 +6,7 @@ package variants import ( "context" + "errors" "reflect" "sync" @@ -58,9 +59,10 @@ func (s *backendSession) handleReceive(ctx context.Context, method string, req m } sessionField := reqVal.FieldByName("Session") - if sessionField.IsValid() && sessionField.CanSet() { - sessionField.Set(reflect.ValueOf(s.serverSession)) + if !sessionField.IsValid() || !sessionField.CanSet() { + return nil, errors.New("variants: request type missing settable Session field") } + sessionField.Set(reflect.ValueOf(s.serverSession)) return s.mcpMethodHandler(ctx, method, req) } From 0fa83f15d6837ec2d1f29bb10c4bb5fd2f7d1ef0 Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Tue, 28 Apr 2026 09:48:27 +0100 Subject: [PATCH 6/7] Shallow copy the request instead of setting the Session in place --- go/sdk/variants/dispatch_test.go | 11 +++++------ go/sdk/variants/session.go | 20 +++++++++++++------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/go/sdk/variants/dispatch_test.go b/go/sdk/variants/dispatch_test.go index 15863f1..b8e0a57 100644 --- a/go/sdk/variants/dispatch_test.go +++ b/go/sdk/variants/dispatch_test.go @@ -403,24 +403,22 @@ func TestHandleDirect_MethodRouting(t *testing.T) { func TestHandleDirect_MetaInjection(t *testing.T) { const variantID = "v1" + var receivedMeta mcp.Meta d := newTestDispatcher(variantID, func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - meta := req.GetParams().GetMeta() - if meta[metaKeyVariant] != variantID { - return nil, errors.New("variant meta not injected") - } + receivedMeta = req.GetParams().GetMeta() return &mcp.CallToolResult{}, nil }) - // Start without variant meta set — handleDirect should inject it. + // Start without variant meta — handleDirect should inject it. req := &mcp.CallToolRequest{ Params: &mcp.CallToolParamsRaw{ - Meta: mcp.Meta{metaKeyVariant: variantID}, Name: "my-tool", }, } _, err := d.handleDirect(context.Background(), "tools/call", req) require.NoError(t, err) + assert.Equal(t, variantID, receivedMeta[metaKeyVariant], "variant meta should be injected into params") } func TestHandleDirect_ErrorEnrichment(t *testing.T) { @@ -587,6 +585,7 @@ func TestHandleReceive_SessionInjection(t *testing.T) { _, err := bs.handleReceive(context.Background(), "tools/list", req) require.NoError(t, err) assert.Same(t, targetSession, receivedSession, "handleReceive should replace Session with inner server session") + assert.Same(t, originalSession, req.Session, "handleReceive should not mutate the original request") } // --------------------------------------------------------------------------- diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index 996f0c9..d2f0449 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -50,21 +50,27 @@ type backendSession struct { // The dispatcher is responsible for modifying params (metadata injection, // cursor unwrapping) before calling this method. func (s *backendSession) handleReceive(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - // Use reflection to modify the Session field in place. We can't wrap the - // request because the SDK's receiving handler does type assertions on the - // concrete request type (e.g., *mcp.ServerRequest[*mcp.CallToolParamsRaw]). + // Shallow-copy the concrete request struct so we don't mutate the caller's object + // while replacing the Session field with our inner server session. + // We can't use a wrapper (like the sending side's sessionSwappedRequest) because + // the SDK's receiving handler does type assertions on the concrete request type. reqVal := reflect.ValueOf(req) - if reqVal.Kind() == reflect.Ptr { - reqVal = reqVal.Elem() + if reqVal.Kind() != reflect.Ptr { + return nil, errors.New("variants: expected pointer to request struct") } - sessionField := reqVal.FieldByName("Session") + // Allocate a new struct of the same concrete type and shallow-copy all fields. + copyPtr := reflect.New(reqVal.Elem().Type()) + copyPtr.Elem().Set(reqVal.Elem()) + + // Set Session on the copy to point to the inner server session. + sessionField := copyPtr.Elem().FieldByName("Session") if !sessionField.IsValid() || !sessionField.CanSet() { return nil, errors.New("variants: request type missing settable Session field") } sessionField.Set(reflect.ValueOf(s.serverSession)) - return s.mcpMethodHandler(ctx, method, req) + return s.mcpMethodHandler(ctx, method, copyPtr.Interface().(mcp.Request)) } // sessionState holds all per-session state for one front client. From c7d2544a82704d07cca4ae3a6f1a8ca31d5d77d1 Mon Sep 17 00:00:00 2001 From: Steven Chen Date: Tue, 28 Apr 2026 10:05:33 +0100 Subject: [PATCH 7/7] Update comments --- go/sdk/variants/dispatch.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index a743c29..3a7b758 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -216,7 +216,6 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ } } - // Generic dispatch - pass entire request object result, err := backendSession.handleReceive(ctx, method, req) if err != nil { return nil, enrichError(err, variantID) @@ -240,8 +239,7 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ // --------------------------------------------------------------------------- // handleDirect handles all simple methods (call, subscribe, unsubscribe, completion) -// that don't require special cursor handling. This consolidates what were previously -// separate handlers for handleCall, handleSubscribe, handleUnsubscribe, and handleCompletion. +// that don't require special cursor handling. func (d *dispatcher) handleDirect(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { conn, err := d.getConnection(ctx, req) if err != nil { @@ -260,7 +258,6 @@ func (d *dispatcher) handleDirect(ctx context.Context, method string, req mcp.Re injectVariantMeta(params, variantID) } - // Generic dispatch - pass entire request object result, err := backendSession.handleReceive(ctx, method, req) if err != nil { return nil, enrichError(err, variantID)