diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 604556692c..b329b5012d 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -8,8 +8,10 @@ import ( "time" "github.com/github/github-mcp-server/internal/ghmcp" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/github" ghhttp "github.com/github/github-mcp-server/pkg/http" + ghoauth "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -34,8 +36,9 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + oauthClientID := viper.GetString("oauth-client-id") + if token == "" && oauthClientID == "" { + return errors.New("authentication required: set GITHUB_PERSONAL_ACCESS_TOKEN, or pass --oauth-client-id to log in via OAuth") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -95,6 +98,29 @@ var ( ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, } + + // When no static token is provided, log in via OAuth using the given + // client. The requested scopes default to the full supported set + // (which filters out no tools); an explicit, narrower --oauth-scopes + // both narrows the grant and hides tools needing other scopes. + if token == "" { + scopes := ghoauth.SupportedScopes + if viper.IsSet("oauth-scopes") { + if err := viper.UnmarshalKey("oauth-scopes", &scopes); err != nil { + return fmt.Errorf("failed to unmarshal oauth-scopes: %w", err) + } + } + oauthConfig := oauth.NewGitHubConfig( + oauthClientID, + viper.GetString("oauth-client-secret"), + scopes, + viper.GetString("host"), + viper.GetInt("oauth-callback-port"), + ) + stdioServerConfig.OAuthManager = oauth.NewManager(oauthConfig, nil) + stdioServerConfig.OAuthScopes = scopes + } + return ghmcp.RunStdioServer(stdioServerConfig) }, } @@ -183,6 +209,14 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") + // stdio-specific OAuth flags. Provide --oauth-client-id (instead of a token) + // to log in via the browser-based OAuth flow on first use. Works for both + // OAuth Apps and GitHub Apps. + stdioCmd.Flags().String("oauth-client-id", "", "OAuth App or GitHub App client ID, enabling interactive OAuth login when no token is set") + stdioCmd.Flags().String("oauth-client-secret", "", "OAuth client secret, if the app requires one (it is a public, non-confidential credential for distributed clients)") + stdioCmd.Flags().StringSlice("oauth-scopes", nil, "Comma-separated OAuth scopes to request; also filters tools to those scopes. Defaults to the full supported set") + stdioCmd.Flags().Int("oauth-callback-port", 0, "Fixed local port for the OAuth callback server. Defaults to a random port; set a fixed port when mapping it through Docker") + // HTTP-specific flags httpCmd.Flags().Int("port", 8082, "HTTP server port") httpCmd.Flags().String("listen-host", "", "Host the HTTP server binds to (e.g. 127.0.0.1). Empty binds to all interfaces.") @@ -205,6 +239,10 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) + _ = viper.BindPFlag("oauth-client-id", stdioCmd.Flags().Lookup("oauth-client-id")) + _ = viper.BindPFlag("oauth-client-secret", stdioCmd.Flags().Lookup("oauth-client-secret")) + _ = viper.BindPFlag("oauth-scopes", stdioCmd.Flags().Lookup("oauth-scopes")) + _ = viper.BindPFlag("oauth-callback-port", stdioCmd.Flags().Lookup("oauth-callback-port")) _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) _ = viper.BindPFlag("listen-host", httpCmd.Flags().Lookup("listen-host")) _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) diff --git a/internal/ghmcp/oauth.go b/internal/ghmcp/oauth.go new file mode 100644 index 0000000000..abc6d3d11c --- /dev/null +++ b/internal/ghmcp/oauth.go @@ -0,0 +1,133 @@ +package ghmcp + +import ( + "context" + "crypto/rand" + "fmt" + "log/slog" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// sessionPrompter adapts an MCP server session to oauth.Prompter, presenting +// authorization prompts to the user via elicitation. Keeping the prompt on the +// MCP control channel (rather than a tool result) keeps the authorization URL +// and any session-bound state out of the model's context. +type sessionPrompter struct { + session *mcp.ServerSession +} + +// elicitationCaps returns the client's declared elicitation capabilities, or nil +// if the client did not advertise any. +func (p *sessionPrompter) elicitationCaps() *mcp.ElicitationCapabilities { + params := p.session.InitializeParams() + if params == nil || params.Capabilities == nil { + return nil + } + return params.Capabilities.Elicitation +} + +// CanPromptURL reports whether the client supports URL-mode elicitation. +func (p *sessionPrompter) CanPromptURL() bool { + caps := p.elicitationCaps() + return caps != nil && caps.URL != nil +} + +// PromptURL presents the authorization URL via URL-mode elicitation and blocks +// until the user acknowledges, declines, or ctx is done. +func (p *sessionPrompter) PromptURL(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "url", + Message: prompt.Message, + URL: prompt.URL, + ElicitationID: rand.Text(), + }) + if err != nil { + // The client advertised URL elicitation but the request itself failed: + // classify it as undeliverable (not a user decision) so the flow can fall + // back to a channel that needs no client capability. + return fmt.Errorf("%w: %w", oauth.ErrPromptUnavailable, err) + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// CanPromptForm reports whether the client supports form-mode elicitation. The +// SDK treats a client that advertises neither form nor URL capabilities as +// supporting forms, for backward compatibility, so we mirror that here. +func (p *sessionPrompter) CanPromptForm() bool { + caps := p.elicitationCaps() + if caps == nil { + return false + } + return caps.Form != nil || caps.URL == nil +} + +// PromptForm presents a textual acknowledgement (used to display a device code +// when URL elicitation is unavailable) and blocks until the user responds. +func (p *sessionPrompter) PromptForm(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "form", + Message: prompt.Message, + }) + if err != nil { + // As with PromptURL, a delivery failure is undeliverable rather than a + // decline, so the flow can fall back instead of aborting. + return fmt.Errorf("%w: %w", oauth.ErrPromptUnavailable, err) + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// oauthAuthenticator is the subset of *oauth.Manager that the middleware needs. +// Depending on the interface (rather than the concrete manager) lets the +// middleware be exercised with a deterministic fake, since driving the real +// manager to its branches would require standing up live GitHub flows. +type oauthAuthenticator interface { + HasToken() bool + Authenticate(ctx context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) +} + +// createOAuthMiddleware returns receiving middleware that authorizes the session +// lazily, on the first tool call. Authorization is deferred until here (rather +// than at startup) because the prompts depend on an initialized session whose +// elicitation capabilities are known. +// +// When a token is already available the call proceeds untouched. Otherwise the +// flow runs: secure channels (browser, URL elicitation) block until the token +// arrives and then the call proceeds; the last-resort channel returns the +// instruction to the user as a tool result and asks them to retry. +func createOAuthMiddleware(mgr oauthAuthenticator, logger *slog.Logger) func(next mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, request mcp.Request) (mcp.Result, error) { + if method != "tools/call" || mgr.HasToken() { + return next(ctx, method, request) + } + + callReq, ok := request.(*mcp.CallToolRequest) + if !ok { + return next(ctx, method, request) + } + + outcome, err := mgr.Authenticate(ctx, &sessionPrompter{session: callReq.Session}) + if err != nil { + return nil, fmt.Errorf("github authorization failed: %w", err) + } + if outcome != nil && outcome.UserAction != nil { + logger.Info("surfacing github authorization instructions to user") + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: outcome.UserAction.Message}}, + }, nil + } + return next(ctx, method, request) + } + } +} + +// ensure sessionPrompter satisfies the Prompter contract. +var _ oauth.Prompter = (*sessionPrompter)(nil) diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go new file mode 100644 index 0000000000..732d080e40 --- /dev/null +++ b/internal/ghmcp/oauth_test.go @@ -0,0 +1,391 @@ +package ghmcp + +import ( + "context" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// probeToolName is the name of the throwaway tool the harness registers; its +// handler runs a probe closure against a sessionPrompter so the adapter can be +// exercised against a real, fully-negotiated server session from the client side. +const probeToolName = "probe" + +// runProbe stands up an in-memory MCP client/server pair, registers a tool whose +// handler runs probe against a sessionPrompter wrapping the live server session, +// and returns the text the probe produced. The client is configured with the +// given capabilities and elicitation handler so the adapter sees a real, +// fully-negotiated session rather than a hand-built fake. +func runProbe( + t *testing.T, + clientCaps *mcp.ClientCapabilities, + elicitationHandler func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error), + probe func(context.Context, *sessionPrompter) string, +) string { + t.Helper() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: probeToolName}, func(ctx context.Context, req *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + text := probe(ctx, &sessionPrompter{session: req.Session}) + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: text}}}, nil, nil + }) + + st, ct := mcp.NewInMemoryTransports() + + ss, err := server.Connect(context.Background(), st, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = ss.Close() }) + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, &mcp.ClientOptions{ + Capabilities: clientCaps, + ElicitationHandler: elicitationHandler, + }) + cs, err := client.Connect(context.Background(), ct, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(context.Background(), &mcp.CallToolParams{Name: probeToolName}) + require.NoError(t, err) + require.Len(t, res.Content, 1) + text, ok := res.Content[0].(*mcp.TextContent) + require.True(t, ok, "probe result should be text content") + return text.Text +} + +func TestSessionPrompterCapabilities(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + caps *mcp.ClientCapabilities + wantURL bool + wantForm bool + }{ + { + name: "no elicitation advertised", + caps: &mcp.ClientCapabilities{}, + wantURL: false, + wantForm: false, + }, + { + name: "url only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}}}, + wantURL: true, + wantForm: false, + }, + { + name: "form only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: false, + wantForm: true, + }, + { + name: "url and form", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}, Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: true, + wantForm: true, + }, + { + name: "empty elicitation capability implies form for backward compatibility", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{}}, + wantURL: false, + wantForm: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := runProbe(t, tc.caps, nil, func(_ context.Context, p *sessionPrompter) string { + if p.CanPromptURL() { + if p.CanPromptForm() { + return "url+form" + } + return "url" + } + if p.CanPromptForm() { + return "form" + } + return "none" + }) + + want := "none" + switch { + case tc.wantURL && tc.wantForm: + want = "url+form" + case tc.wantURL: + want = "url" + case tc.wantForm: + want = "form" + } + assert.Equal(t, want, got) + }) + } +} + +func TestSessionPrompterPromptActions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + action string + wantDecline bool + }{ + {name: "accept", action: "accept", wantDecline: false}, + {name: "decline", action: "decline", wantDecline: true}, + {name: "cancel", action: "cancel", wantDecline: true}, + } + + caps := &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{ + URL: &mcp.URLElicitationCapabilities{}, + Form: &mcp.FormElicitationCapabilities{}, + }} + + for _, tc := range tests { + // URL and form modes share the accept/decline mapping; cover both. + for _, mode := range []string{"url", "form"} { + t.Run(tc.name+"/"+mode, func(t *testing.T) { + t.Parallel() + + handler := func(_ context.Context, _ *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: tc.action}, nil + } + + got := runProbe(t, caps, handler, func(ctx context.Context, p *sessionPrompter) string { + var err error + if mode == "url" { + err = p.PromptURL(ctx, oauth.Prompt{Message: "msg", URL: "https://example.com/auth"}) + } else { + err = p.PromptForm(ctx, oauth.Prompt{Message: "msg"}) + } + if err == nil { + return "ok" + } + if err == oauth.ErrPromptDeclined { + return "declined" + } + return "error: " + err.Error() + }) + + if tc.wantDecline { + assert.Equal(t, "declined", got) + } else { + assert.Equal(t, "ok", got) + } + }) + } + } +} + +// TestSessionPrompterTransportError verifies that a prompt which fails to be +// delivered (the client errors instead of returning an action) is reported as +// ErrPromptUnavailable, not ErrPromptDeclined. The manager relies on this +// distinction to fall back to manual instructions instead of aborting. +func TestSessionPrompterTransportError(t *testing.T) { + t.Parallel() + + caps := &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{ + URL: &mcp.URLElicitationCapabilities{}, + Form: &mcp.FormElicitationCapabilities{}, + }} + + for _, mode := range []string{"url", "form"} { + t.Run(mode, func(t *testing.T) { + t.Parallel() + + handler := func(_ context.Context, _ *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return nil, errors.New("client cannot deliver elicitation") + } + + got := runProbe(t, caps, handler, func(ctx context.Context, p *sessionPrompter) string { + var err error + if mode == "url" { + err = p.PromptURL(ctx, oauth.Prompt{Message: "msg", URL: "https://example.com/auth"}) + } else { + err = p.PromptForm(ctx, oauth.Prompt{Message: "msg"}) + } + switch { + case err == nil: + return "ok" + case errors.Is(err, oauth.ErrPromptDeclined): + return "declined" + case errors.Is(err, oauth.ErrPromptUnavailable): + return "unavailable" + default: + return "error: " + err.Error() + } + }) + + assert.Equal(t, "unavailable", got, + "a delivery failure must be classified as undeliverable, not a decline") + }) + } +} + +// fakeAuthenticator is a deterministic stand-in for *oauth.Manager that lets the +// middleware be tested at each branch without standing up live GitHub flows. +type fakeAuthenticator struct { + hasToken bool + outcome *oauth.Outcome + err error + authCalls int + lastPrompter oauth.Prompter +} + +func (f *fakeAuthenticator) HasToken() bool { return f.hasToken } + +func (f *fakeAuthenticator) Authenticate(_ context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) { + f.authCalls++ + f.lastPrompter = prompter + return f.outcome, f.err +} + +func TestCreateOAuthMiddleware(t *testing.T) { + t.Parallel() + + const nextText = "handler-ran" + newNext := func(called *bool) mcp.MethodHandler { + return func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + *called = true + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: nextText}}}, nil + } + } + + t.Run("non tool call passes through without authenticating", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "initialize", &mcp.InitializeRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must not run for non tool calls") + }) + + t.Run("existing token short circuits authentication", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: true} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must be skipped when a token already exists") + }) + + t.Run("successful authentication proceeds to handler", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, outcome: nil, err: nil} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.Equal(t, 1, fake.authCalls) + assert.True(t, called, "next should run once authorized") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, nextText, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("pending user action is surfaced as a tool result", func(t *testing.T) { + t.Parallel() + const message = "Open https://example.com/auth to authorize, then retry." + fake := &fakeAuthenticator{hasToken: false, outcome: &oauth.Outcome{UserAction: &oauth.UserAction{Message: message}}} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.False(t, called, "next must not run while the user still needs to authorize") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, message, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("authentication error is returned", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, err: assert.AnError} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.Error(t, err) + assert.ErrorIs(t, err, assert.AnError) + assert.False(t, called, "next must not run when authentication fails") + }) +} + +// TestRunStdioServerRejectsTokenAndOAuth verifies the mutually-exclusive guard: +// supplying both a static token and an OAuth manager is rejected before the +// server starts, rather than silently preferring one for auth and the other for +// scope filtering. +func TestRunStdioServerRejectsTokenAndOAuth(t *testing.T) { + t.Parallel() + + mgr := oauth.NewManager(oauth.NewGitHubConfig("client-id", "", nil, "", 0), discardLogger()) + err := RunStdioServer(StdioServerConfig{ + Token: "ghp_static", + OAuthManager: mgr, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + +// TestCreateGitHubClientsTokenProvider proves the OAuth wiring: when a +// TokenProvider is configured the REST client authenticates with the provider's +// current token on every request (and never pins a stale one), which is what the +// lazy, refreshing OAuth token depends on. +func TestCreateGitHubClientsTokenProvider(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + apiHost, err := utils.NewAPIHost(server.URL) + require.NoError(t, err) + + clients, err := createGitHubClients(github.MCPServerConfig{ + Version: "test", + TokenProvider: func() string { return current }, + }, apiHost) + require.NoError(t, err) + + do := func() { + resp, err := clients.rest.Client().Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "", gotAuth, "no auth header before authorization") + + current = "oauth-token" + do() + assert.Equal(t, "Bearer oauth-token", gotAuth, "provider token used once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed provider token used") +} diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a37c4d940d..1bf84453c8 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/transport" @@ -61,16 +62,30 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } - // Construct REST client + // Construct REST client. When a TokenProvider is configured (OAuth), we + // authenticate via BearerAuthTransport and skip go-github's WithAuthToken: + // the latter installs its own round tripper that would pin the static token + // and shadow the dynamic one. restUATransport := &transport.UserAgentTransport{ Transport: http.DefaultTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClient, err := gogithub.NewClient( - gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithAuthToken(cfg.Token), - gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - ) + var restClient *gogithub.Client + if cfg.TokenProvider != nil { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: &transport.BearerAuthTransport{ + Transport: restUATransport, + TokenProvider: cfg.TokenProvider, + }}), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } else { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithAuthToken(cfg.Token), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } @@ -82,7 +97,8 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, - Token: cfg.Token, + Token: cfg.Token, + TokenProvider: cfg.TokenProvider, }, } @@ -229,10 +245,29 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // OAuthManager, when non-nil, enables OAuth 2.1 login for stdio mode. The + // server starts without a token and runs the authorization flow on the + // first tool call (see createOAuthMiddleware). It is mutually exclusive with + // a static Token. + OAuthManager *oauth.Manager + + // OAuthScopes are the scopes requested during OAuth login. They double as + // the scope set for tool filtering: tools requiring a scope outside this set + // are hidden. The default set is the full supported list, which hides + // nothing; an explicit, narrower list filters accordingly. + OAuthScopes []string } // RunStdioServer is not concurrent safe. func RunStdioServer(cfg StdioServerConfig) error { + // OAuth login and a static token are mutually exclusive: they would + // disagree on how the token is sourced (lazy provider vs. static) and on + // scope filtering, so reject the ambiguous combination up front. + if cfg.OAuthManager != nil && cfg.Token != "" { + return fmt.Errorf("OAuthManager and a static Token are mutually exclusive: provide one or the other") + } + // Create app context ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -255,11 +290,13 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) - // Fetch token scopes for scope-based tool filtering (PAT tokens only) - // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. - // Fine-grained PATs and other token types don't support this, so we skip filtering. + // Determine the scope set used to filter tools. Classic PATs expose their + // granted scopes via the API; OAuth uses the requested scopes (the default + // set hides nothing, a narrower explicit set filters accordingly). Other + // token types don't advertise scopes, so filtering is skipped. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + switch { + case strings.HasPrefix(cfg.Token, "ghp_"): fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -267,10 +304,20 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + case cfg.OAuthManager != nil: + tokenScopes = cfg.OAuthScopes + logger.Info("using requested OAuth scopes for tool filtering", "scopes", tokenScopes) + default: logger.Debug("skipping scope filtering for non-PAT token") } + // For OAuth, the token is resolved lazily: empty until the user authorizes + // on the first tool call, then refreshed for the rest of the session. + var tokenProvider func() string + if cfg.OAuthManager != nil { + tokenProvider = cfg.OAuthManager.AccessToken + } + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, @@ -287,11 +334,18 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, + TokenProvider: tokenProvider, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } + // With OAuth, intercept tool calls to run the authorization flow on first + // use, before the handler tries to call GitHub with an empty token. + if cfg.OAuthManager != nil { + ghServer.AddReceivingMiddleware(createOAuthMiddleware(cfg.OAuthManager, logger)) + } + if cfg.ExportTranslations { // Once server is initialized, all translations are loaded dumpTranslations() diff --git a/internal/oauth/flow.go b/internal/oauth/flow.go index 76d0298951..fda1dda19f 100644 --- a/internal/oauth/flow.go +++ b/internal/oauth/flow.go @@ -24,9 +24,15 @@ type flowPlan struct { // poll the device endpoint) and returns the token. run func(context.Context) (*oauth2.Token, error) // display, if set, presents the prompt to the user via the Prompter and - // blocks until they act. A non-nil error (including ErrPromptDeclined) - // aborts the flow. + // blocks until they act. ErrPromptDeclined (the user said no) or any other + // error aborts the flow, except ErrPromptUnavailable, which degrades to + // fallback when that is set. display func(context.Context) error + // fallback, if set alongside display, is the manual user action to surface + // when the display prompt cannot be delivered (ErrPromptUnavailable). It lets + // a runtime elicitation failure degrade to the manual channel — keeping the + // background flow alive — instead of aborting. + fallback *UserAction // userAction, if set, indicates the last-resort channel: the caller must // surface it and the user retries after authorizing out of band. userAction *UserAction @@ -124,6 +130,16 @@ func (m *Manager) beginPKCE(prompter Prompter) (*flowPlan, error) { m.logger.Debug("browser auto-open unavailable", "reason", browserErr) } + // The manual instructions double as the fallback if a chosen display channel + // turns out to be undeliverable at runtime, so build them once here. + manual := &UserAction{ + URL: authURL, + Message: fmt.Sprintf( + "To authorize the GitHub MCP Server, open this URL in your browser:\n\n%s\n\nAfter authorizing, retry your request.\n\n%s", + authURL, securityAdvisory, + ), + } + if canPromptURL(prompter) { display := func(ctx context.Context) error { return prompter.PromptURL(ctx, Prompt{ @@ -131,16 +147,10 @@ func (m *Manager) beginPKCE(prompter Prompter) (*flowPlan, error) { URL: authURL, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } - return &flowPlan{run: run, userAction: &UserAction{ - URL: authURL, - Message: fmt.Sprintf( - "To authorize the GitHub MCP Server, open this URL in your browser:\n\n%s\n\nAfter authorizing, retry your request.\n\n%s", - authURL, securityAdvisory, - ), - }}, nil + return &flowPlan{run: run, userAction: manual}, nil } // beginDevice prepares the device authorization flow. It requests a device code @@ -164,6 +174,17 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { return tok, nil } + // As with PKCE, the manual instructions double as the runtime fallback, so + // build them once and reuse for both display plans and the last resort. + manual := &UserAction{ + URL: da.VerificationURI, + UserCode: da.UserCode, + Message: fmt.Sprintf( + "%s\n\nAfter authorizing, retry your request.\n\n%s", + deviceInstruction(da), securityAdvisory, + ), + } + if canPromptURL(prompter) { display := func(ctx context.Context) error { return prompter.PromptURL(ctx, Prompt{ @@ -172,7 +193,7 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { UserCode: da.UserCode, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } if canPromptForm(prompter) { @@ -183,17 +204,10 @@ func (m *Manager) beginDevice(prompter Prompter) (*flowPlan, error) { UserCode: da.UserCode, }) } - return &flowPlan{run: run, display: display}, nil + return &flowPlan{run: run, display: display, fallback: manual}, nil } - return &flowPlan{run: run, userAction: &UserAction{ - URL: da.VerificationURI, - UserCode: da.UserCode, - Message: fmt.Sprintf( - "%s\n\nAfter authorizing, retry your request.\n\n%s", - deviceInstruction(da), securityAdvisory, - ), - }}, nil + return &flowPlan{run: run, userAction: manual}, nil } // securityAdvisory nudges users on clients without URL elicitation to ask their diff --git a/internal/oauth/manager.go b/internal/oauth/manager.go index 8d1fe0f302..a78e919df7 100644 --- a/internal/oauth/manager.go +++ b/internal/oauth/manager.go @@ -190,14 +190,32 @@ func (m *Manager) Authenticate(ctx context.Context, prompter Prompter) (*Outcome } // runFlow executes a prepared flow in the background and records the result. The -// optional display prompt runs concurrently; if it ends in error or decline it -// cancels the flow. +// optional display prompt runs concurrently: a decline (or other failure) aborts +// the flow, while an undeliverable prompt degrades to the manual fallback without +// tearing the flow down, so the user can still authorize out of band. func (m *Manager) runFlow(ctx context.Context, cancel context.CancelFunc, plan *flowPlan) { defer cancel() if plan.display != nil { go func() { - if err := plan.display(ctx); err != nil { + err := plan.display(ctx) + switch { + case err == nil: + // Prompt shown; the flow completes when the token arrives. + case ctx.Err() != nil: + // The flow is already ending (timed out or cancelled elsewhere), + // so there is nothing to fall back to. Checking this before the + // fallback also prevents misreading a context-cancelled prompt as + // a transport failure. + case errors.Is(err, ErrPromptUnavailable) && plan.fallback != nil: + // The client advertised the capability but could not deliver the + // prompt. Surface the manual instructions instead of failing, and + // keep the background flow alive so the user can still authorize. + m.logger.Debug("authorization prompt undeliverable; falling back to manual instructions", "reason", err) + m.fallBackToUserAction(plan.fallback) + default: + // A user decline (ErrPromptDeclined) or any other prompt failure + // ends the flow. m.logger.Debug("authorization prompt closed", "reason", err) cancel() } @@ -208,6 +226,26 @@ func (m *Manager) runFlow(ctx context.Context, cancel context.CancelFunc, plan * m.complete(tok, err) } +// fallBackToUserAction promotes a running secure flow to the manual user-action +// channel after its prompt could not be delivered. The background flow keeps +// running, so the user can complete authorization out of band and retry. It is a +// no-op if the flow has already resolved. +func (m *Manager) fallBackToUserAction(ua *UserAction) { + m.mu.Lock() + defer m.mu.Unlock() + if m.status != statusInProgress { + return + } + m.status = statusAwaitingUser + m.pending = ua + // Wake any callers joined on this flow so they receive the action, and clear + // done so complete() does not double-close it when run() later finishes. + if m.done != nil { + close(m.done) + m.done = nil + } +} + // complete records the flow result, installing a refreshing token source on // success, and wakes any joined callers. func (m *Manager) complete(tok *oauth2.Token, err error) { @@ -236,7 +274,9 @@ func (m *Manager) complete(tok *oauth2.Token, err error) { } } -// joinWait blocks until the running flow finishes or ctx is cancelled. +// joinWait blocks until the running flow finishes or ctx is cancelled. If the +// flow was promoted to the manual channel while waiting (its prompt could not be +// delivered), it returns that user action rather than an error. func (m *Manager) joinWait(ctx context.Context, done chan struct{}) (*Outcome, error) { select { case <-done: @@ -244,8 +284,12 @@ func (m *Manager) joinWait(ctx context.Context, done chan struct{}) (*Outcome, e return nil, nil } m.mu.Lock() + pending := m.pending err := m.lastErr m.mu.Unlock() + if pending != nil { + return &Outcome{UserAction: pending}, nil + } if err != nil { return nil, err } diff --git a/internal/oauth/manager_test.go b/internal/oauth/manager_test.go index fb8323246d..6f43c03ef9 100644 --- a/internal/oauth/manager_test.go +++ b/internal/oauth/manager_test.go @@ -117,6 +117,42 @@ func TestAuthenticateDeclinedPromptFails(t *testing.T) { assert.Empty(t, m.AccessToken()) } +func TestAuthenticateUndeliverablePromptFallsBack(t *testing.T) { + f := newFakeGitHub(t) + m := newManager(t, f) + m.openURL = func(string) error { return errors.New("no browser") } + + // The client advertised URL elicitation but delivering the prompt fails (a + // transport/protocol error, not a user decision). This must degrade to the + // manual instructions rather than aborting like a decline does. + prompter := &fakePrompter{ + urlCapable: true, + onURL: func(_ context.Context, _ Prompt) error { + return ErrPromptUnavailable + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + out, err := m.Authenticate(ctx, prompter) + require.NoError(t, err, "an undeliverable prompt must not abort the flow") + require.NotNil(t, out) + require.NotNil(t, out.UserAction, "an undeliverable prompt must fall back to a user action") + assert.NotEmpty(t, out.UserAction.URL) + assert.Contains(t, out.UserAction.Message, securityAdvisory) + + // A concurrent retry while awaiting the user returns the same fallback action. + out2, err := m.Authenticate(ctx, nil) + require.NoError(t, err) + require.NotNil(t, out2.UserAction) + assert.Equal(t, out.UserAction.URL, out2.UserAction.URL) + + // The background flow stayed alive: opening the URL out of band completes it. + require.NoError(t, browserGet(out.UserAction.URL)) + assert.Equal(t, "gho_access", waitForToken(t, m)) +} + func TestAuthenticateLastDitchUserAction(t *testing.T) { f := newFakeGitHub(t) m := newManager(t, f) diff --git a/internal/oauth/prompter.go b/internal/oauth/prompter.go index 9225589367..1d6aa6ec61 100644 --- a/internal/oauth/prompter.go +++ b/internal/oauth/prompter.go @@ -5,10 +5,18 @@ import ( "errors" ) -// ErrPromptDeclined is returned by a Prompter when the user cancels or declines -// the authorization prompt. +// ErrPromptDeclined is returned by a Prompter when the user actively cancels or +// declines the authorization prompt. It is a deliberate "no", so the flow stops +// rather than falling back to another channel. var ErrPromptDeclined = errors.New("authorization declined by user") +// ErrPromptUnavailable is returned by a Prompter when the prompt could not be +// delivered at all — for example the client advertised an elicitation capability +// but the request failed at the transport or protocol level. Unlike +// ErrPromptDeclined it reflects no user decision, so the flow falls back to a +// channel that needs no client capability instead of giving up. +var ErrPromptUnavailable = errors.New("authorization prompt could not be delivered") + // Prompt is the content shown to the user when asking them to authorize. type Prompt struct { // Message is a human-readable instruction. @@ -36,7 +44,8 @@ type Prompter interface { // until the user acknowledges, declines, or ctx is done. Returning nil means // the prompt was shown (not that authorization completed); the caller waits // for the OAuth flow itself to finish. It returns ErrPromptDeclined if the - // user declines or cancels. + // user declines or cancels, or ErrPromptUnavailable if the prompt could not + // be delivered. PromptURL(ctx context.Context, p Prompt) error // CanPromptForm reports whether the client supports form elicitation, used @@ -44,7 +53,8 @@ type Prompter interface { CanPromptForm() bool // PromptForm presents a textual acknowledgement prompt and blocks until the - // user responds. It returns ErrPromptDeclined if the user declines. + // user responds. It returns ErrPromptDeclined if the user declines, or + // ErrPromptUnavailable if the prompt could not be delivered. PromptForm(ctx context.Context, p Prompt) error } diff --git a/pkg/github/server.go b/pkg/github/server.go index 7ec5837c3a..627cc678b2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -68,6 +68,11 @@ type MCPServerConfig struct { // This is used for PAT scope filtering where we can't issue scope challenges. TokenScopes []string + // TokenProvider, when non-nil, supplies the GitHub token for each API + // request instead of the static Token. It backs OAuth login, where the + // token is obtained lazily on first use and refreshed thereafter. + TokenProvider func() string + // Additional server options to apply ServerOptions []MCPServerOption } diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index ffa7669a9d..f7ffe67e6b 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -19,7 +19,13 @@ const ( OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" ) -// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +// SupportedScopes lists every OAuth scope that an MCP tool may require. It is the +// source of truth in two places: HTTP mode advertises it as scopes_supported in +// the protected-resource metadata, and stdio OAuth login requests it by default +// and then filters the exposed tools to the granted scopes. A tool whose required +// scope is absent here is therefore hidden under default OAuth even though a PAT +// carrying that scope would expose it, so keep this list in sync with tool scope +// requirements when scopes change. var SupportedScopes = []string{ "repo", "read:org", diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go index 66922bbdaa..0c12ddfc91 100644 --- a/pkg/http/transport/bearer.go +++ b/pkg/http/transport/bearer.go @@ -11,11 +11,25 @@ import ( type BearerAuthTransport struct { Transport http.RoundTripper Token string + + // TokenProvider, when non-nil, supplies the bearer token for each request + // and takes precedence over Token. It backs OAuth, where the token is + // obtained after the client is built and is refreshed over the session's + // lifetime. It may return an empty string before authorization completes. + TokenProvider func() string } func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) - req.Header.Set(headers.AuthorizationHeader, "Bearer "+t.Token) + token := t.Token + if t.TokenProvider != nil { + token = t.TokenProvider() + } + // Before OAuth authorization completes the token is empty; send an + // unauthenticated request rather than an empty "Bearer " header. + if token != "" { + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + } // Check for GraphQL-Features in context and add header if present if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { diff --git a/pkg/http/transport/bearer_test.go b/pkg/http/transport/bearer_test.go new file mode 100644 index 0000000000..76ef8686cd --- /dev/null +++ b/pkg/http/transport/bearer_test.go @@ -0,0 +1,164 @@ +package transport + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBearerAuthTransport(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + token string + tokenProvider func() string + wantAuth string + }{ + { + name: "static token", + token: "static-token", + wantAuth: "Bearer static-token", + }, + { + name: "token provider takes precedence over static token", + token: "static-token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider with empty static token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider may return empty before authorization", + tokenProvider: func() string { return "" }, + wantAuth: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: tc.token, + TokenProvider: tc.tokenProvider, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.wantAuth, gotAuth) + }) + } +} + +// TestBearerAuthTransport_TokenProviderResolvedPerRequest verifies that the +// token provider is consulted on every request, so a token that arrives (or is +// refreshed) after the transport is constructed takes effect without rebuilding +// the client. This is the property OAuth relies on. +func TestBearerAuthTransport_TokenProviderResolvedPerRequest(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + TokenProvider: func() string { return current }, + } + + do := func() { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "", gotAuth, "no auth header before authorization") + + current = "first-token" + do() + assert.Equal(t, "Bearer first-token", gotAuth, "token picked up once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed token picked up") +} + +func TestBearerAuthTransport_PassesGraphQLFeaturesHeader(t *testing.T) { + t.Parallel() + + var gotFeatures string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotFeatures = r.Header.Get(headers.GraphQLFeaturesHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + ctx := ghcontext.WithGraphQLFeatures(context.Background(), "feature1", "feature2") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "feature1, feature2", gotFeatures) +} + +func TestBearerAuthTransport_DoesNotMutateOriginalRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Empty(t, req.Header.Get(headers.AuthorizationHeader), "original request must not be mutated") +}