From 20cc4cefb8d090f720512ac6ce5eaece13765fef Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Jun 2026 19:05:00 +0200 Subject: [PATCH 1/4] feat(oauth): wire stdio OAuth 2.1 login into the server Connect the internal/oauth core library to the stdio MCP server so users can authenticate with an OAuth App or GitHub App client ID instead of a static personal access token. - BearerAuthTransport gains a TokenProvider that is consulted per request, letting the lazily-acquired, auto-refreshing OAuth token take effect without rebuilding the client. - createGitHubClients uses BearerAuthTransport (and skips go-github's WithAuthToken, which would pin a static token) when a TokenProvider is set. - RunStdioServer starts without a token and installs receiving middleware that runs the authorization flow on the first tool call, surfacing the auth URL or device code via elicitation (or a tool result as a fallback). - Tool filtering uses the requested OAuth scopes; the default supported set hides nothing, while a narrower --oauth-scopes both narrows the grant and filters tools accordingly. - A sessionPrompter adapts the MCP server session to oauth.Prompter, keeping the authorization URL off the model's context. - New stdio flags: --oauth-client-id/-client-secret/-scopes/-callback-port. This is stdio-only and deliberately does not touch MCP-HTTP auth. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/github-mcp-server/main.go | 42 +++- internal/ghmcp/oauth.go | 128 ++++++++++++ internal/ghmcp/oauth_test.go | 329 ++++++++++++++++++++++++++++++ internal/ghmcp/server.go | 71 +++++-- pkg/github/server.go | 5 + pkg/http/transport/bearer.go | 12 +- pkg/http/transport/bearer_test.go | 164 +++++++++++++++ 7 files changed, 736 insertions(+), 15 deletions(-) create mode 100644 internal/ghmcp/oauth.go create mode 100644 internal/ghmcp/oauth_test.go create mode 100644 pkg/http/transport/bearer_test.go diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 604556692c..b329b5012d 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -8,8 +8,10 @@ import ( "time" "github.com/github/github-mcp-server/internal/ghmcp" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/github" ghhttp "github.com/github/github-mcp-server/pkg/http" + ghoauth "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -34,8 +36,9 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + oauthClientID := viper.GetString("oauth-client-id") + if token == "" && oauthClientID == "" { + return errors.New("authentication required: set GITHUB_PERSONAL_ACCESS_TOKEN, or pass --oauth-client-id to log in via OAuth") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -95,6 +98,29 @@ var ( ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, } + + // When no static token is provided, log in via OAuth using the given + // client. The requested scopes default to the full supported set + // (which filters out no tools); an explicit, narrower --oauth-scopes + // both narrows the grant and hides tools needing other scopes. + if token == "" { + scopes := ghoauth.SupportedScopes + if viper.IsSet("oauth-scopes") { + if err := viper.UnmarshalKey("oauth-scopes", &scopes); err != nil { + return fmt.Errorf("failed to unmarshal oauth-scopes: %w", err) + } + } + oauthConfig := oauth.NewGitHubConfig( + oauthClientID, + viper.GetString("oauth-client-secret"), + scopes, + viper.GetString("host"), + viper.GetInt("oauth-callback-port"), + ) + stdioServerConfig.OAuthManager = oauth.NewManager(oauthConfig, nil) + stdioServerConfig.OAuthScopes = scopes + } + return ghmcp.RunStdioServer(stdioServerConfig) }, } @@ -183,6 +209,14 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") + // stdio-specific OAuth flags. Provide --oauth-client-id (instead of a token) + // to log in via the browser-based OAuth flow on first use. Works for both + // OAuth Apps and GitHub Apps. + stdioCmd.Flags().String("oauth-client-id", "", "OAuth App or GitHub App client ID, enabling interactive OAuth login when no token is set") + stdioCmd.Flags().String("oauth-client-secret", "", "OAuth client secret, if the app requires one (it is a public, non-confidential credential for distributed clients)") + stdioCmd.Flags().StringSlice("oauth-scopes", nil, "Comma-separated OAuth scopes to request; also filters tools to those scopes. Defaults to the full supported set") + stdioCmd.Flags().Int("oauth-callback-port", 0, "Fixed local port for the OAuth callback server. Defaults to a random port; set a fixed port when mapping it through Docker") + // HTTP-specific flags httpCmd.Flags().Int("port", 8082, "HTTP server port") httpCmd.Flags().String("listen-host", "", "Host the HTTP server binds to (e.g. 127.0.0.1). Empty binds to all interfaces.") @@ -205,6 +239,10 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) + _ = viper.BindPFlag("oauth-client-id", stdioCmd.Flags().Lookup("oauth-client-id")) + _ = viper.BindPFlag("oauth-client-secret", stdioCmd.Flags().Lookup("oauth-client-secret")) + _ = viper.BindPFlag("oauth-scopes", stdioCmd.Flags().Lookup("oauth-scopes")) + _ = viper.BindPFlag("oauth-callback-port", stdioCmd.Flags().Lookup("oauth-callback-port")) _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) _ = viper.BindPFlag("listen-host", httpCmd.Flags().Lookup("listen-host")) _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) diff --git a/internal/ghmcp/oauth.go b/internal/ghmcp/oauth.go new file mode 100644 index 0000000000..6a1d388956 --- /dev/null +++ b/internal/ghmcp/oauth.go @@ -0,0 +1,128 @@ +package ghmcp + +import ( + "context" + "crypto/rand" + "fmt" + "log/slog" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// sessionPrompter adapts an MCP server session to oauth.Prompter, presenting +// authorization prompts to the user via elicitation. Keeping the prompt on the +// MCP control channel (rather than a tool result) keeps the authorization URL +// and any session-bound state out of the model's context. +type sessionPrompter struct { + session *mcp.ServerSession +} + +// elicitationCaps returns the client's declared elicitation capabilities, or nil +// if the client did not advertise any. +func (p *sessionPrompter) elicitationCaps() *mcp.ElicitationCapabilities { + params := p.session.InitializeParams() + if params == nil || params.Capabilities == nil { + return nil + } + return params.Capabilities.Elicitation +} + +// CanPromptURL reports whether the client supports URL-mode elicitation. +func (p *sessionPrompter) CanPromptURL() bool { + caps := p.elicitationCaps() + return caps != nil && caps.URL != nil +} + +// PromptURL presents the authorization URL via URL-mode elicitation and blocks +// until the user acknowledges, declines, or ctx is done. +func (p *sessionPrompter) PromptURL(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "url", + Message: prompt.Message, + URL: prompt.URL, + ElicitationID: rand.Text(), + }) + if err != nil { + return err + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// CanPromptForm reports whether the client supports form-mode elicitation. The +// SDK treats a client that advertises neither form nor URL capabilities as +// supporting forms, for backward compatibility, so we mirror that here. +func (p *sessionPrompter) CanPromptForm() bool { + caps := p.elicitationCaps() + if caps == nil { + return false + } + return caps.Form != nil || caps.URL == nil +} + +// PromptForm presents a textual acknowledgement (used to display a device code +// when URL elicitation is unavailable) and blocks until the user responds. +func (p *sessionPrompter) PromptForm(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "form", + Message: prompt.Message, + }) + if err != nil { + return err + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// oauthAuthenticator is the subset of *oauth.Manager that the middleware needs. +// Depending on the interface (rather than the concrete manager) lets the +// middleware be exercised with a deterministic fake, since driving the real +// manager to its branches would require standing up live GitHub flows. +type oauthAuthenticator interface { + HasToken() bool + Authenticate(ctx context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) +} + +// createOAuthMiddleware returns receiving middleware that authorizes the session +// lazily, on the first tool call. Authorization is deferred until here (rather +// than at startup) because the prompts depend on an initialized session whose +// elicitation capabilities are known. +// +// When a token is already available the call proceeds untouched. Otherwise the +// flow runs: secure channels (browser, URL elicitation) block until the token +// arrives and then the call proceeds; the last-resort channel returns the +// instruction to the user as a tool result and asks them to retry. +func createOAuthMiddleware(mgr oauthAuthenticator, logger *slog.Logger) func(next mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, request mcp.Request) (mcp.Result, error) { + if method != "tools/call" || mgr.HasToken() { + return next(ctx, method, request) + } + + callReq, ok := request.(*mcp.CallToolRequest) + if !ok { + return next(ctx, method, request) + } + + outcome, err := mgr.Authenticate(ctx, &sessionPrompter{session: callReq.Session}) + if err != nil { + return nil, fmt.Errorf("github authorization failed: %w", err) + } + if outcome != nil && outcome.UserAction != nil { + logger.Info("surfacing github authorization instructions to user") + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: outcome.UserAction.Message}}, + }, nil + } + return next(ctx, method, request) + } + } +} + +// ensure sessionPrompter satisfies the Prompter contract. +var _ oauth.Prompter = (*sessionPrompter)(nil) diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go new file mode 100644 index 0000000000..4f370cf7bc --- /dev/null +++ b/internal/ghmcp/oauth_test.go @@ -0,0 +1,329 @@ +package ghmcp + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// probeToolName is the name of the throwaway tool the harness registers; its +// handler runs a probe closure against a sessionPrompter so the adapter can be +// exercised against a real, fully-negotiated server session from the client side. +const probeToolName = "probe" + +// runProbe stands up an in-memory MCP client/server pair, registers a tool whose +// handler runs probe against a sessionPrompter wrapping the live server session, +// and returns the text the probe produced. The client is configured with the +// given capabilities and elicitation handler so the adapter sees a real, +// fully-negotiated session rather than a hand-built fake. +func runProbe( + t *testing.T, + clientCaps *mcp.ClientCapabilities, + elicitationHandler func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error), + probe func(context.Context, *sessionPrompter) string, +) string { + t.Helper() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: probeToolName}, func(ctx context.Context, req *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + text := probe(ctx, &sessionPrompter{session: req.Session}) + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: text}}}, nil, nil + }) + + st, ct := mcp.NewInMemoryTransports() + + ss, err := server.Connect(context.Background(), st, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = ss.Close() }) + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, &mcp.ClientOptions{ + Capabilities: clientCaps, + ElicitationHandler: elicitationHandler, + }) + cs, err := client.Connect(context.Background(), ct, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(context.Background(), &mcp.CallToolParams{Name: probeToolName}) + require.NoError(t, err) + require.Len(t, res.Content, 1) + text, ok := res.Content[0].(*mcp.TextContent) + require.True(t, ok, "probe result should be text content") + return text.Text +} + +func TestSessionPrompterCapabilities(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + caps *mcp.ClientCapabilities + wantURL bool + wantForm bool + }{ + { + name: "no elicitation advertised", + caps: &mcp.ClientCapabilities{}, + wantURL: false, + wantForm: false, + }, + { + name: "url only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}}}, + wantURL: true, + wantForm: false, + }, + { + name: "form only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: false, + wantForm: true, + }, + { + name: "url and form", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}, Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: true, + wantForm: true, + }, + { + name: "empty elicitation capability implies form for backward compatibility", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{}}, + wantURL: false, + wantForm: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := runProbe(t, tc.caps, nil, func(_ context.Context, p *sessionPrompter) string { + if p.CanPromptURL() { + if p.CanPromptForm() { + return "url+form" + } + return "url" + } + if p.CanPromptForm() { + return "form" + } + return "none" + }) + + want := "none" + switch { + case tc.wantURL && tc.wantForm: + want = "url+form" + case tc.wantURL: + want = "url" + case tc.wantForm: + want = "form" + } + assert.Equal(t, want, got) + }) + } +} + +func TestSessionPrompterPromptActions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + action string + wantDecline bool + }{ + {name: "accept", action: "accept", wantDecline: false}, + {name: "decline", action: "decline", wantDecline: true}, + {name: "cancel", action: "cancel", wantDecline: true}, + } + + caps := &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{ + URL: &mcp.URLElicitationCapabilities{}, + Form: &mcp.FormElicitationCapabilities{}, + }} + + for _, tc := range tests { + // URL and form modes share the accept/decline mapping; cover both. + for _, mode := range []string{"url", "form"} { + t.Run(tc.name+"/"+mode, func(t *testing.T) { + t.Parallel() + + handler := func(_ context.Context, _ *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: tc.action}, nil + } + + got := runProbe(t, caps, handler, func(ctx context.Context, p *sessionPrompter) string { + var err error + if mode == "url" { + err = p.PromptURL(ctx, oauth.Prompt{Message: "msg", URL: "https://example.com/auth"}) + } else { + err = p.PromptForm(ctx, oauth.Prompt{Message: "msg"}) + } + if err == nil { + return "ok" + } + if err == oauth.ErrPromptDeclined { + return "declined" + } + return "error: " + err.Error() + }) + + if tc.wantDecline { + assert.Equal(t, "declined", got) + } else { + assert.Equal(t, "ok", got) + } + }) + } + } +} + +// fakeAuthenticator is a deterministic stand-in for *oauth.Manager that lets the +// middleware be tested at each branch without standing up live GitHub flows. +type fakeAuthenticator struct { + hasToken bool + outcome *oauth.Outcome + err error + authCalls int + lastPrompter oauth.Prompter +} + +func (f *fakeAuthenticator) HasToken() bool { return f.hasToken } + +func (f *fakeAuthenticator) Authenticate(_ context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) { + f.authCalls++ + f.lastPrompter = prompter + return f.outcome, f.err +} + +func TestCreateOAuthMiddleware(t *testing.T) { + t.Parallel() + + const nextText = "handler-ran" + newNext := func(called *bool) mcp.MethodHandler { + return func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + *called = true + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: nextText}}}, nil + } + } + + t.Run("non tool call passes through without authenticating", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "initialize", &mcp.InitializeRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must not run for non tool calls") + }) + + t.Run("existing token short circuits authentication", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: true} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must be skipped when a token already exists") + }) + + t.Run("successful authentication proceeds to handler", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, outcome: nil, err: nil} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.Equal(t, 1, fake.authCalls) + assert.True(t, called, "next should run once authorized") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, nextText, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("pending user action is surfaced as a tool result", func(t *testing.T) { + t.Parallel() + const message = "Open https://example.com/auth to authorize, then retry." + fake := &fakeAuthenticator{hasToken: false, outcome: &oauth.Outcome{UserAction: &oauth.UserAction{Message: message}}} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.False(t, called, "next must not run while the user still needs to authorize") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, message, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("authentication error is returned", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, err: assert.AnError} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.Error(t, err) + assert.ErrorIs(t, err, assert.AnError) + assert.False(t, called, "next must not run when authentication fails") + }) +} + +// TestCreateGitHubClientsTokenProvider proves the OAuth wiring: when a +// TokenProvider is configured the REST client authenticates with the provider's +// current token on every request (and never pins a stale one), which is what the +// lazy, refreshing OAuth token depends on. +func TestCreateGitHubClientsTokenProvider(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + apiHost, err := utils.NewAPIHost(server.URL) + require.NoError(t, err) + + clients, err := createGitHubClients(github.MCPServerConfig{ + Version: "test", + TokenProvider: func() string { return current }, + }, apiHost) + require.NoError(t, err) + + do := func() { + resp, err := clients.rest.Client().Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "Bearer", gotAuth, "no token before authorization") + + current = "oauth-token" + do() + assert.Equal(t, "Bearer oauth-token", gotAuth, "provider token used once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed provider token used") +} diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a37c4d940d..2364b02688 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/transport" @@ -61,16 +62,30 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } - // Construct REST client + // Construct REST client. When a TokenProvider is configured (OAuth), we + // authenticate via BearerAuthTransport and skip go-github's WithAuthToken: + // the latter installs its own round tripper that would pin the static token + // and shadow the dynamic one. restUATransport := &transport.UserAgentTransport{ Transport: http.DefaultTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClient, err := gogithub.NewClient( - gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithAuthToken(cfg.Token), - gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - ) + var restClient *gogithub.Client + if cfg.TokenProvider != nil { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: &transport.BearerAuthTransport{ + Transport: restUATransport, + TokenProvider: cfg.TokenProvider, + }}), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } else { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithAuthToken(cfg.Token), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } @@ -82,7 +97,8 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, - Token: cfg.Token, + Token: cfg.Token, + TokenProvider: cfg.TokenProvider, }, } @@ -229,6 +245,18 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // OAuthManager, when non-nil, enables OAuth 2.1 login for stdio mode. The + // server starts without a token and runs the authorization flow on the + // first tool call (see createOAuthMiddleware). It is mutually exclusive with + // a static Token. + OAuthManager *oauth.Manager + + // OAuthScopes are the scopes requested during OAuth login. They double as + // the scope set for tool filtering: tools requiring a scope outside this set + // are hidden. The default set is the full supported list, which hides + // nothing; an explicit, narrower list filters accordingly. + OAuthScopes []string } // RunStdioServer is not concurrent safe. @@ -255,11 +283,13 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) - // Fetch token scopes for scope-based tool filtering (PAT tokens only) - // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. - // Fine-grained PATs and other token types don't support this, so we skip filtering. + // Determine the scope set used to filter tools. Classic PATs expose their + // granted scopes via the API; OAuth uses the requested scopes (the default + // set hides nothing, a narrower explicit set filters accordingly). Other + // token types don't advertise scopes, so filtering is skipped. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + switch { + case strings.HasPrefix(cfg.Token, "ghp_"): fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -267,10 +297,20 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + case cfg.OAuthManager != nil: + tokenScopes = cfg.OAuthScopes + logger.Info("using requested OAuth scopes for tool filtering", "scopes", tokenScopes) + default: logger.Debug("skipping scope filtering for non-PAT token") } + // For OAuth, the token is resolved lazily: empty until the user authorizes + // on the first tool call, then refreshed for the rest of the session. + var tokenProvider func() string + if cfg.OAuthManager != nil { + tokenProvider = cfg.OAuthManager.AccessToken + } + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, @@ -287,11 +327,18 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, + TokenProvider: tokenProvider, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } + // With OAuth, intercept tool calls to run the authorization flow on first + // use, before the handler tries to call GitHub with an empty token. + if cfg.OAuthManager != nil { + ghServer.AddReceivingMiddleware(createOAuthMiddleware(cfg.OAuthManager, logger)) + } + if cfg.ExportTranslations { // Once server is initialized, all translations are loaded dumpTranslations() diff --git a/pkg/github/server.go b/pkg/github/server.go index 7ec5837c3a..627cc678b2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -68,6 +68,11 @@ type MCPServerConfig struct { // This is used for PAT scope filtering where we can't issue scope challenges. TokenScopes []string + // TokenProvider, when non-nil, supplies the GitHub token for each API + // request instead of the static Token. It backs OAuth login, where the + // token is obtained lazily on first use and refreshed thereafter. + TokenProvider func() string + // Additional server options to apply ServerOptions []MCPServerOption } diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go index 66922bbdaa..9be3fd5342 100644 --- a/pkg/http/transport/bearer.go +++ b/pkg/http/transport/bearer.go @@ -11,11 +11,21 @@ import ( type BearerAuthTransport struct { Transport http.RoundTripper Token string + + // TokenProvider, when non-nil, supplies the bearer token for each request + // and takes precedence over Token. It backs OAuth, where the token is + // obtained after the client is built and is refreshed over the session's + // lifetime. It may return an empty string before authorization completes. + TokenProvider func() string } func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) - req.Header.Set(headers.AuthorizationHeader, "Bearer "+t.Token) + token := t.Token + if t.TokenProvider != nil { + token = t.TokenProvider() + } + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) // Check for GraphQL-Features in context and add header if present if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { diff --git a/pkg/http/transport/bearer_test.go b/pkg/http/transport/bearer_test.go new file mode 100644 index 0000000000..550144b866 --- /dev/null +++ b/pkg/http/transport/bearer_test.go @@ -0,0 +1,164 @@ +package transport + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBearerAuthTransport(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + token string + tokenProvider func() string + wantAuth string + }{ + { + name: "static token", + token: "static-token", + wantAuth: "Bearer static-token", + }, + { + name: "token provider takes precedence over static token", + token: "static-token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider with empty static token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider may return empty before authorization", + tokenProvider: func() string { return "" }, + wantAuth: "Bearer", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: tc.token, + TokenProvider: tc.tokenProvider, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.wantAuth, gotAuth) + }) + } +} + +// TestBearerAuthTransport_TokenProviderResolvedPerRequest verifies that the +// token provider is consulted on every request, so a token that arrives (or is +// refreshed) after the transport is constructed takes effect without rebuilding +// the client. This is the property OAuth relies on. +func TestBearerAuthTransport_TokenProviderResolvedPerRequest(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + TokenProvider: func() string { return current }, + } + + do := func() { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "Bearer", gotAuth, "no token yet before authorization") + + current = "first-token" + do() + assert.Equal(t, "Bearer first-token", gotAuth, "token picked up once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed token picked up") +} + +func TestBearerAuthTransport_PassesGraphQLFeaturesHeader(t *testing.T) { + t.Parallel() + + var gotFeatures string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotFeatures = r.Header.Get(headers.GraphQLFeaturesHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + ctx := ghcontext.WithGraphQLFeatures(context.Background(), "feature1", "feature2") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "feature1, feature2", gotFeatures) +} + +func TestBearerAuthTransport_DoesNotMutateOriginalRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Empty(t, req.Header.Get(headers.AuthorizationHeader), "original request must not be mutated") +} From 622d429004373b3f508ed1ccbef164a1c5973e84 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Jun 2026 19:14:33 +0200 Subject: [PATCH 2/4] =?UTF-8?q?refactor(oauth):=20address=20review=20?= =?UTF-8?q?=E2=80=94=20omit=20empty=20bearer=20header,=20guard=20token/oau?= =?UTF-8?q?th?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - BearerAuthTransport omits the Authorization header entirely when the token is empty (pre-authorization) rather than sending an empty "Bearer " value. - RunStdioServer rejects the ambiguous combination of a static Token and an OAuthManager up front, enforcing the documented mutual exclusivity. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/ghmcp/oauth_test.go | 18 +++++++++++++++++- internal/ghmcp/server.go | 7 +++++++ pkg/http/transport/bearer.go | 6 +++++- pkg/http/transport/bearer_test.go | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go index 4f370cf7bc..826b000185 100644 --- a/internal/ghmcp/oauth_test.go +++ b/internal/ghmcp/oauth_test.go @@ -286,6 +286,22 @@ func TestCreateOAuthMiddleware(t *testing.T) { }) } +// TestRunStdioServerRejectsTokenAndOAuth verifies the mutually-exclusive guard: +// supplying both a static token and an OAuth manager is rejected before the +// server starts, rather than silently preferring one for auth and the other for +// scope filtering. +func TestRunStdioServerRejectsTokenAndOAuth(t *testing.T) { + t.Parallel() + + mgr := oauth.NewManager(oauth.NewGitHubConfig("client-id", "", nil, "", 0), discardLogger()) + err := RunStdioServer(StdioServerConfig{ + Token: "ghp_static", + OAuthManager: mgr, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + // TestCreateGitHubClientsTokenProvider proves the OAuth wiring: when a // TokenProvider is configured the REST client authenticates with the provider's // current token on every request (and never pins a stale one), which is what the @@ -317,7 +333,7 @@ func TestCreateGitHubClientsTokenProvider(t *testing.T) { } do() - assert.Equal(t, "Bearer", gotAuth, "no token before authorization") + assert.Equal(t, "", gotAuth, "no auth header before authorization") current = "oauth-token" do() diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 2364b02688..1bf84453c8 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -261,6 +261,13 @@ type StdioServerConfig struct { // RunStdioServer is not concurrent safe. func RunStdioServer(cfg StdioServerConfig) error { + // OAuth login and a static token are mutually exclusive: they would + // disagree on how the token is sourced (lazy provider vs. static) and on + // scope filtering, so reject the ambiguous combination up front. + if cfg.OAuthManager != nil && cfg.Token != "" { + return fmt.Errorf("OAuthManager and a static Token are mutually exclusive: provide one or the other") + } + // Create app context ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go index 9be3fd5342..0c12ddfc91 100644 --- a/pkg/http/transport/bearer.go +++ b/pkg/http/transport/bearer.go @@ -25,7 +25,11 @@ func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, erro if t.TokenProvider != nil { token = t.TokenProvider() } - req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + // Before OAuth authorization completes the token is empty; send an + // unauthenticated request rather than an empty "Bearer " header. + if token != "" { + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + } // Check for GraphQL-Features in context and add header if present if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { diff --git a/pkg/http/transport/bearer_test.go b/pkg/http/transport/bearer_test.go index 550144b866..76ef8686cd 100644 --- a/pkg/http/transport/bearer_test.go +++ b/pkg/http/transport/bearer_test.go @@ -41,7 +41,7 @@ func TestBearerAuthTransport(t *testing.T) { { name: "token provider may return empty before authorization", tokenProvider: func() string { return "" }, - wantAuth: "Bearer", + wantAuth: "", }, } @@ -103,7 +103,7 @@ func TestBearerAuthTransport_TokenProviderResolvedPerRequest(t *testing.T) { } do() - assert.Equal(t, "Bearer", gotAuth, "no token yet before authorization") + assert.Equal(t, "", gotAuth, "no auth header before authorization") current = "first-token" do() From 2b4d5e60a67c7bc54821102e80fe8204fa314bb0 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 18 Jun 2026 10:57:54 +0200 Subject: [PATCH 3/4] docs(oauth): clarify SupportedScopes is the stdio default and tool filter Document that stdio OAuth login requests these scopes by default and then filters the exposed tools to the scopes actually granted, so a tool whose required scope is absent from this list is hidden under default OAuth even though a PAT carrying that scope would expose it. Keep the list in sync with tool scope requirements when scopes change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/http/oauth/oauth.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index ffa7669a9d..f7ffe67e6b 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -19,7 +19,13 @@ const ( OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" ) -// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +// SupportedScopes lists every OAuth scope that an MCP tool may require. It is the +// source of truth in two places: HTTP mode advertises it as scopes_supported in +// the protected-resource metadata, and stdio OAuth login requests it by default +// and then filters the exposed tools to the granted scopes. A tool whose required +// scope is absent here is therefore hidden under default OAuth even though a PAT +// carrying that scope would expose it, so keep this list in sync with tool scope +// requirements when scopes change. var SupportedScopes = []string{ "repo", "read:org", From b7e81b854e05b8e797eef4b9371b2fd1cbeba059 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 25 Jun 2026 21:25:50 +0200 Subject: [PATCH 4/4] Distinguish undeliverable auth prompts from user declines An elicitation prompt that the client cannot deliver (a transport or protocol failure) was treated the same as a user actively declining: any display error cancelled the flow. That conflated a system failure with a deliberate "no", so a client that advertised URL elicitation but failed to deliver it would hard-fail the login instead of degrading. Add an ErrPromptUnavailable sentinel alongside ErrPromptDeclined and have the MCP adapter return it when Elicit fails at the transport level. The manager now falls back to the manual user-action channel on an undeliverable prompt (keeping the background flow alive so the user can still authorize out of band), while a genuine decline still aborts. A context-cancelled prompt is checked first so an ending flow is never misread as a transport failure. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/ghmcp/oauth.go | 9 ++++-- internal/ghmcp/oauth_test.go | 46 +++++++++++++++++++++++++++++ internal/oauth/flow.go | 54 +++++++++++++++++++++------------- internal/oauth/manager.go | 52 +++++++++++++++++++++++++++++--- internal/oauth/manager_test.go | 36 +++++++++++++++++++++++ internal/oauth/prompter.go | 18 +++++++++--- 6 files changed, 185 insertions(+), 30 deletions(-) diff --git a/internal/ghmcp/oauth.go b/internal/ghmcp/oauth.go index 6a1d388956..abc6d3d11c 100644 --- a/internal/ghmcp/oauth.go +++ b/internal/ghmcp/oauth.go @@ -44,7 +44,10 @@ func (p *sessionPrompter) PromptURL(ctx context.Context, prompt oauth.Prompt) er ElicitationID: rand.Text(), }) if err != nil { - return err + // The client advertised URL elicitation but the request itself failed: + // classify it as undeliverable (not a user decision) so the flow can fall + // back to a channel that needs no client capability. + return fmt.Errorf("%w: %w", oauth.ErrPromptUnavailable, err) } if res.Action != "accept" { return oauth.ErrPromptDeclined @@ -71,7 +74,9 @@ func (p *sessionPrompter) PromptForm(ctx context.Context, prompt oauth.Prompt) e Message: prompt.Message, }) if err != nil { - return err + // As with PromptURL, a delivery failure is undeliverable rather than a + // decline, so the flow can fall back instead of aborting. + return fmt.Errorf("%w: %w", oauth.ErrPromptUnavailable, err) } if res.Action != "accept" { return oauth.ErrPromptDeclined diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go index 826b000185..732d080e40 100644 --- a/internal/ghmcp/oauth_test.go +++ b/internal/ghmcp/oauth_test.go @@ -2,6 +2,7 @@ package ghmcp import ( "context" + "errors" "io" "log/slog" "net/http" @@ -193,6 +194,51 @@ func TestSessionPrompterPromptActions(t *testing.T) { } } +// TestSessionPrompterTransportError verifies that a prompt which fails to be +// delivered (the client errors instead of returning an action) is reported as +// ErrPromptUnavailable, not ErrPromptDeclined. The manager relies on this +// distinction to fall back to manual instructions instead of aborting. +func TestSessionPrompterTransportError(t *testing.T) { + t.Parallel() + + caps := &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{ + URL: &mcp.URLElicitationCapabilities{}, + Form: &mcp.FormElicitationCapabilities{}, + }} + + for _, mode := range []string{"url", "form"} { + t.Run(mode, func(t *testing.T) { + t.Parallel() + + handler := func(_ context.Context, _ *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return nil, errors.New("client cannot deliver elicitation") + } + + got := runProbe(t, caps, handler, func(ctx context.Context, p *sessionPrompter) string { + var err error + if mode == "url" { + err = p.PromptURL(ctx, oauth.Prompt{Message: "msg", URL: "https://example.com/auth"}) + } else { + err = p.PromptForm(ctx, oauth.Prompt{Message: "msg"}) + } + switch { + case err == nil: + return "ok" + case errors.Is(err, oauth.ErrPromptDeclined): + return "declined" + case errors.Is(err, oauth.ErrPromptUnavailable): + return "unavailable" + default: + return "error: " + err.Error() + } + }) + + assert.Equal(t, "unavailable", got, + "a delivery failure must be classified as undeliverable, not a decline") + }) + } +} + // fakeAuthenticator is a deterministic stand-in for *oauth.Manager that lets the // middleware be tested at each branch without standing up live GitHub flows. type fakeAuthenticator struct { diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 76d0298951..fda1dda19f 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -24,9 +24,15 @@ type flowPlan struct { // poll the device endpoint) and returns the token. run func(context.Context) (*oauth2.Token, error) // display, if set, presents the prompt to the user via the Prompter and - // blocks until they act. A non-nil error (including ErrPromptDeclined) - // aborts the flow. + // blocks until they act. ErrPromptDeclined (the user said no) or any other + // error aborts the flow, except ErrPromptUnavailable, which degrades to + // fallback when that is set. display func(context.Context) error + // fallback, if set alongside display, is the manual user action to surface + // when the display prompt cannot be delivered (ErrPromptUnavailable). It lets + // a runtime elicitation failure degrade to the manual channel — keeping the + // background flow alive — instead of aborting. + fallback *UserAction // userAction, if set, indicates the last-resort channel: the caller must // surface it and the user retries after authorizing out of band. userAction *UserAction @@ -124,6 +130,16 @@ func (m *Manager) beginPKCE(prompter Prompter) (*flowPlan, error) { m.logger.Debug("browser auto-open unavailable", "reason", browserErr) } + // The manual instructions double as the fallback if a chosen display channel + // turns out to be undeliverable at runtime, so build them once here. + manual := &UserAction{ + URL: authURL, + Message: fmt.Sprintf( + "To authorize the GitHub MCP Server, open this URL in your browser:\n\n%s\n\nAfter authorizing, retry your request.\n\n%s", + authURL, securityAdvisory, + ), + } + if canPromptURL(prompter) { display := func(ctx context.Context) error { return prompter.PromptURL(ctx, Prompt{ @@ -131,16 +147,10 @@ func (m *Manager) beginPKCE(prompter Prompter) (*flowPlan, error) { URL: authURL, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } - return &flowPlan{run: run, userAction: &UserAction{ - URL: authURL, - Message: fmt.Sprintf( - "To authorize the GitHub MCP Server, open this URL in your browser:\n\n%s\n\nAfter authorizing, retry your request.\n\n%s", - authURL, securityAdvisory, - ), - }}, nil + return &flowPlan{run: run, userAction: manual}, nil } // beginDevice prepares the device authorization flow. It requests a device code @@ -164,6 +174,17 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { return tok, nil } + // As with PKCE, the manual instructions double as the runtime fallback, so + // build them once and reuse for both display plans and the last resort. + manual := &UserAction{ + URL: da.VerificationURI, + UserCode: da.UserCode, + Message: fmt.Sprintf( + "%s\n\nAfter authorizing, retry your request.\n\n%s", + deviceInstruction(da), securityAdvisory, + ), + } + if canPromptURL(prompter) { display := func(ctx context.Context) error { return prompter.PromptURL(ctx, Prompt{ @@ -172,7 +193,7 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { UserCode: da.UserCode, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } if canPromptForm(prompter) { @@ -183,17 +204,10 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { UserCode: da.UserCode, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } - return &flowPlan{run: run, userAction: &UserAction{ - URL: da.VerificationURI, - UserCode: da.UserCode, - Message: fmt.Sprintf( - "%s\n\nAfter authorizing, retry your request.\n\n%s", - deviceInstruction(da), securityAdvisory, - ), - }}, nil + return &flowPlan{run: run, userAction: manual}, nil } // securityAdvisory nudges users on clients without URL elicitation to ask their diff --git a/internal/oauth/manager.go b/internal/oauth/manager.go index 8d1fe0f302..a78e919df7 100644 --- a/internal/oauth/manager.go +++ b/internal/oauth/manager.go @@ -190,14 +190,32 @@ func (m *Manager) Authenticate(ctx context.Context, prompter Prompter) (*Outcome } // runFlow executes a prepared flow in the background and records the result. The -// optional display prompt runs concurrently; if it ends in error or decline it -// cancels the flow. +// optional display prompt runs concurrently: a decline (or other failure) aborts +// the flow, while an undeliverable prompt degrades to the manual fallback without +// tearing the flow down, so the user can still authorize out of band. func (m *Manager) runFlow(ctx context.Context, cancel context.CancelFunc, plan *flowPlan) { defer cancel() if plan.display != nil { go func() { - if err := plan.display(ctx); err != nil { + err := plan.display(ctx) + switch { + case err == nil: + // Prompt shown; the flow completes when the token arrives. + case ctx.Err() != nil: + // The flow is already ending (timed out or cancelled elsewhere), + // so there is nothing to fall back to. Checking this before the + // fallback also prevents misreading a context-cancelled prompt as + // a transport failure. + case errors.Is(err, ErrPromptUnavailable) && plan.fallback != nil: + // The client advertised the capability but could not deliver the + // prompt. Surface the manual instructions instead of failing, and + // keep the background flow alive so the user can still authorize. + m.logger.Debug("authorization prompt undeliverable; falling back to manual instructions", "reason", err) + m.fallBackToUserAction(plan.fallback) + default: + // A user decline (ErrPromptDeclined) or any other prompt failure + // ends the flow. m.logger.Debug("authorization prompt closed", "reason", err) cancel() } @@ -208,6 +226,26 @@ func (m *Manager) runFlow(ctx context.Context, cancel context.CancelFunc, plan * m.complete(tok, err) } +// fallBackToUserAction promotes a running secure flow to the manual user-action +// channel after its prompt could not be delivered. The background flow keeps +// running, so the user can complete authorization out of band and retry. It is a +// no-op if the flow has already resolved. +func (m *Manager) fallBackToUserAction(ua *UserAction) { + m.mu.Lock() + defer m.mu.Unlock() + if m.status != statusInProgress { + return + } + m.status = statusAwaitingUser + m.pending = ua + // Wake any callers joined on this flow so they receive the action, and clear + // done so complete() does not double-close it when run() later finishes. + if m.done != nil { + close(m.done) + m.done = nil + } +} + // complete records the flow result, installing a refreshing token source on // success, and wakes any joined callers. func (m *Manager) complete(tok *oauth2.Token, err error) { @@ -236,7 +274,9 @@ func (m *Manager) complete(tok *oauth2.Token, err error) { } } -// joinWait blocks until the running flow finishes or ctx is cancelled. +// joinWait blocks until the running flow finishes or ctx is cancelled. If the +// flow was promoted to the manual channel while waiting (its prompt could not be +// delivered), it returns that user action rather than an error. func (m *Manager) joinWait(ctx context.Context, done chan struct{}) (*Outcome, error) { select { case <-done: @@ -244,8 +284,12 @@ func (m *Manager) joinWait(ctx context.Context, done chan struct{}) (*Outcome, e return nil, nil } m.mu.Lock() + pending := m.pending err := m.lastErr m.mu.Unlock() + if pending != nil { + return &Outcome{UserAction: pending}, nil + } if err != nil { return nil, err } diff --git a/internal/oauth/manager_test.go b/internal/oauth/manager_test.go index fb8323246d..6f43c03ef9 100644 --- a/internal/oauth/manager_test.go +++ b/internal/oauth/manager_test.go @@ -117,6 +117,42 @@ func TestAuthenticateDeclinedPromptFails(t *testing.T) { assert.Empty(t, m.AccessToken()) } +func TestAuthenticateUndeliverablePromptFallsBack(t *testing.T) { + f := newFakeGitHub(t) + m := newManager(t, f) + m.openURL = func(string) error { return errors.New("no browser") } + + // The client advertised URL elicitation but delivering the prompt fails (a + // transport/protocol error, not a user decision). This must degrade to the + // manual instructions rather than aborting like a decline does. + prompter := &fakePrompter{ + urlCapable: true, + onURL: func(_ context.Context, _ Prompt) error { + return ErrPromptUnavailable + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out, err := m.Authenticate(ctx, prompter) + require.NoError(t, err, "an undeliverable prompt must not abort the flow") + require.NotNil(t, out) + require.NotNil(t, out.UserAction, "an undeliverable prompt must fall back to a user action") + assert.NotEmpty(t, out.UserAction.URL) + assert.Contains(t, out.UserAction.Message, securityAdvisory) + + // A concurrent retry while awaiting the user returns the same fallback action. + out2, err := m.Authenticate(ctx, nil) + require.NoError(t, err) + require.NotNil(t, out2.UserAction) + assert.Equal(t, out.UserAction.URL, out2.UserAction.URL) + + // The background flow stayed alive: opening the URL out of band completes it. + require.NoError(t, browserGet(out.UserAction.URL)) + assert.Equal(t, "gho_access", waitForToken(t, m)) +} + func TestAuthenticateLastDitchUserAction(t *testing.T) { f := newFakeGitHub(t) m := newManager(t, f) diff --git a/internal/oauth/prompter.go b/internal/oauth/prompter.go index 9225589367..1d6aa6ec61 100644 --- a/internal/oauth/prompter.go +++ b/internal/oauth/prompter.go @@ -5,10 +5,18 @@ import ( "errors" ) -// ErrPromptDeclined is returned by a Prompter when the user cancels or declines -// the authorization prompt. +// ErrPromptDeclined is returned by a Prompter when the user actively cancels or +// declines the authorization prompt. It is a deliberate "no", so the flow stops +// rather than falling back to another channel. var ErrPromptDeclined = errors.New("authorization declined by user") +// ErrPromptUnavailable is returned by a Prompter when the prompt could not be +// delivered at all — for example the client advertised an elicitation capability +// but the request failed at the transport or protocol level. Unlike +// ErrPromptDeclined it reflects no user decision, so the flow falls back to a +// channel that needs no client capability instead of giving up. +var ErrPromptUnavailable = errors.New("authorization prompt could not be delivered") + // Prompt is the content shown to the user when asking them to authorize. type Prompt struct { // Message is a human-readable instruction. @@ -36,7 +44,8 @@ type Prompter interface { // until the user acknowledges, declines, or ctx is done. Returning nil means // the prompt was shown (not that authorization completed); the caller waits // for the OAuth flow itself to finish. It returns ErrPromptDeclined if the - // user declines or cancels. + // user declines or cancels, or ErrPromptUnavailable if the prompt could not + // be delivered. PromptURL(ctx context.Context, p Prompt) error // CanPromptForm reports whether the client supports form elicitation, used @@ -44,7 +53,8 @@ type Prompter interface { CanPromptForm() bool // PromptForm presents a textual acknowledgement prompt and blocks until the - // user responds. It returns ErrPromptDeclined if the user declines. + // user responds. It returns ErrPromptDeclined if the user declines, or + // ErrPromptUnavailable if the prompt could not be delivered. PromptForm(ctx context.Context, p Prompt) error }