diff --git a/internal/config/schema.go b/internal/config/schema.go index dfb3b12..3076680 100644 --- a/internal/config/schema.go +++ b/internal/config/schema.go @@ -49,6 +49,12 @@ type GitConfig struct { AutoPullInterval time.Duration `yaml:"auto_pull_interval,omitempty"` } +// OAuthConfig holds OAuth server configuration. +type OAuthConfig struct { + AccessTokenTTL time.Duration `yaml:"access_token_ttl,omitempty"` + RefreshTokenTTL time.Duration `yaml:"refresh_token_ttl,omitempty"` +} + // MCPConfig holds MCP server-related configuration for AI agent integration. type MCPConfig struct { Bind string `yaml:"bind,omitempty"` @@ -68,6 +74,7 @@ type MCPConfig struct { TLSCertFile string `yaml:"tls_cert_file,omitempty"` TLSKeyFile string `yaml:"tls_key_file,omitempty"` AllowInsecureBind bool `yaml:"allow_insecure_bind,omitempty"` + OAuth *OAuthConfig `yaml:"oauth,omitempty"` } // UpdateConfig holds update check-related configuration. @@ -137,6 +144,10 @@ func defaultMCPConfig() MCPConfig { ApprovalTimeout: 30 * time.Second, RateLimit: 60, MetricsAuthRequired: true, + OAuth: &OAuthConfig{ + AccessTokenTTL: 24 * time.Hour, + RefreshTokenTTL: 720 * time.Hour, + }, } } @@ -197,6 +208,13 @@ type fileGitConfig struct { CommitTemplate *string `yaml:"commit_template,omitempty"` } +// fileOAuthConfig is the file-based OAuth configuration with pointer fields +// for optional YAML unmarshaling. +type fileOAuthConfig struct { + AccessTokenTTL *time.Duration `yaml:"access_token_ttl,omitempty"` + RefreshTokenTTL *time.Duration `yaml:"refresh_token_ttl,omitempty"` +} + // fileMCPConfig is the file-based MCP configuration with pointer fields // for optional YAML unmarshaling. type fileMCPConfig struct { @@ -217,6 +235,7 @@ type fileMCPConfig struct { TLSCertFile *string `yaml:"tls_cert_file,omitempty"` TLSKeyFile *string `yaml:"tls_key_file,omitempty"` AllowInsecureBind *bool `yaml:"allow_insecure_bind,omitempty"` + OAuth *fileOAuthConfig `yaml:"oauth,omitempty"` } // fileUpdateConfig is the file-based update configuration with pointer fields @@ -363,6 +382,17 @@ func MergeFileMCPConfig(fileCfg *fileMCPConfig, defaults MCPConfig) MCPConfig { if fileCfg.AllowInsecureBind != nil { result.AllowInsecureBind = *fileCfg.AllowInsecureBind } + if fileCfg.OAuth != nil { + if result.OAuth == nil { + result.OAuth = &OAuthConfig{} + } + if fileCfg.OAuth.AccessTokenTTL != nil { + result.OAuth.AccessTokenTTL = *fileCfg.OAuth.AccessTokenTTL + } + if fileCfg.OAuth.RefreshTokenTTL != nil { + result.OAuth.RefreshTokenTTL = *fileCfg.OAuth.RefreshTokenTTL + } + } return result } diff --git a/internal/config/schema_oauth_test.go b/internal/config/schema_oauth_test.go new file mode 100644 index 0000000..7d4f977 --- /dev/null +++ b/internal/config/schema_oauth_test.go @@ -0,0 +1,74 @@ +package config + +import ( + "testing" + "time" +) + +func TestDefaultOAuthConfig(t *testing.T) { + cfg := defaultMCPConfig() + if cfg.OAuth == nil { + t.Fatal("OAuth config is nil in defaults") + } + if cfg.OAuth.AccessTokenTTL != 24*time.Hour { + t.Errorf("default AccessTokenTTL = %v, want 24h", cfg.OAuth.AccessTokenTTL) + } + if cfg.OAuth.RefreshTokenTTL != 720*time.Hour { + t.Errorf("default RefreshTokenTTL = %v, want 720h (30d)", cfg.OAuth.RefreshTokenTTL) + } +} + +func TestMergeFileOAuthConfig(t *testing.T) { + accessTTL := 10 * time.Second + refreshTTL := 30 * time.Second + + fileCfg := &fileMCPConfig{ + OAuth: &fileOAuthConfig{ + AccessTokenTTL: &accessTTL, + RefreshTokenTTL: &refreshTTL, + }, + } + + result := MergeFileMCPConfig(fileCfg, defaultMCPConfig()) + + if result.OAuth == nil { + t.Fatal("OAuth config is nil after merge") + } + if result.OAuth.AccessTokenTTL != accessTTL { + t.Errorf("AccessTokenTTL = %v, want %v", result.OAuth.AccessTokenTTL, accessTTL) + } + if result.OAuth.RefreshTokenTTL != refreshTTL { + t.Errorf("RefreshTokenTTL = %v, want %v", result.OAuth.RefreshTokenTTL, refreshTTL) + } +} + +func TestMergeFileOAuthConfig_PartialOverride(t *testing.T) { + refreshTTL := 100 * time.Hour + + fileCfg := &fileMCPConfig{ + OAuth: &fileOAuthConfig{ + RefreshTokenTTL: &refreshTTL, + }, + } + + result := MergeFileMCPConfig(fileCfg, defaultMCPConfig()) + + if result.OAuth.AccessTokenTTL != 24*time.Hour { + t.Errorf("AccessTokenTTL = %v, want default 24h", result.OAuth.AccessTokenTTL) + } + if result.OAuth.RefreshTokenTTL != 100*time.Hour { + t.Errorf("RefreshTokenTTL = %v, want 100h", result.OAuth.RefreshTokenTTL) + } +} + +func TestMergeFileOAuthConfig_NilOAuth(t *testing.T) { + fileCfg := &fileMCPConfig{} + result := MergeFileMCPConfig(fileCfg, defaultMCPConfig()) + + if result.OAuth == nil { + t.Fatal("OAuth config should not be nil after merge with nil file cfg") + } + if result.OAuth.AccessTokenTTL != 24*time.Hour { + t.Errorf("AccessTokenTTL = %v, want default 24h", result.OAuth.AccessTokenTTL) + } +} diff --git a/internal/mcp/serverbootstrap/http.go b/internal/mcp/serverbootstrap/http.go index 4e3c26e..db7eacb 100644 --- a/internal/mcp/serverbootstrap/http.go +++ b/internal/mcp/serverbootstrap/http.go @@ -190,7 +190,17 @@ func RunHTTPServerOnListener(ctx context.Context, listener net.Listener, v *vaul mux.HandleFunc("GET /mcp/oauth/authorize", oauthAuthorizeHandler.ServeHTTP) // Token endpoint uses the scoped token registry instead of the legacy bearer token. - oauthTokenHandler := mcp.OriginValidationMiddleware(addr, handleOAuthToken(oauthStore, registry)) + accessTokenTTL := 24 * time.Hour + refreshTokenTTL := 720 * time.Hour + if v != nil && v.Config != nil && v.Config.MCP != nil && v.Config.MCP.OAuth != nil { + if v.Config.MCP.OAuth.AccessTokenTTL > 0 { + accessTokenTTL = v.Config.MCP.OAuth.AccessTokenTTL + } + if v.Config.MCP.OAuth.RefreshTokenTTL > 0 { + refreshTokenTTL = v.Config.MCP.OAuth.RefreshTokenTTL + } + } + oauthTokenHandler := mcp.OriginValidationMiddleware(addr, handleOAuthToken(oauthStore, registry, accessTokenTTL, refreshTokenTTL)) mux.HandleFunc("POST /mcp/oauth/token", oauthTokenHandler.ServeHTTP) const maxRequestBodySize = 1 * 1024 * 1024 diff --git a/internal/mcp/serverbootstrap/oauth.go b/internal/mcp/serverbootstrap/oauth.go index b47fd8c..4e4dc9f 100644 --- a/internal/mcp/serverbootstrap/oauth.go +++ b/internal/mcp/serverbootstrap/oauth.go @@ -125,7 +125,7 @@ func handleOAuthRegister(clientStore *oauthClientStore) http.HandlerFunc { "client_id_issued_at": time.Now().Unix(), "client_secret_expires_at": 0, "token_endpoint_auth_method": "none", - "grant_types": []string{"authorization_code"}, + "grant_types": []string{"authorization_code", "refresh_token"}, "response_types": []string{"code"}, "redirect_uris": req.RedirectURIs, }) @@ -222,55 +222,84 @@ func handleOAuthAuthorize(store *oauthCodeStore, clientStore *oauthClientStore) } // handleOAuthToken implements the authorization code grant (RFC 6749 §4.1.3) -// with PKCE verification (RFC 7636). On success it mints a fresh scoped MCP -// token via the TokenRegistry instead of returning the global legacy bearer -// token. The scoped token has a 24-hour TTL. -func handleOAuthToken(store *oauthCodeStore, registry *mcp.TokenRegistry) http.HandlerFunc { +// with PKCE verification (RFC 7636) and refresh token support (RFC 6749 §6). +// On success it mints a fresh scoped MCP token via the TokenRegistry instead +// of returning the global legacy bearer token. +func handleOAuthToken(store *oauthCodeStore, registry *mcp.TokenRegistry, accessTokenTTL, refreshTokenTTL time.Duration) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_request"}) return } - if r.FormValue("grant_type") != "authorization_code" { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": "unsupported_grant_type"}) - return - } if registry == nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"}) return } - pending, ok := store.take(r.FormValue("code")) - if !ok || time.Now().After(pending.expiresAt) { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"}) - return - } - if !verifyS256(r.FormValue("code_verifier"), pending.codeChallenge) { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"}) - return - } + grantType := r.FormValue("grant_type") + + switch grantType { + case "authorization_code": + pending, ok := store.take(r.FormValue("code")) + if !ok || time.Now().After(pending.expiresAt) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"}) + return + } + if !verifyS256(r.FormValue("code_verifier"), pending.codeChallenge) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant"}) + return + } + + label := fmt.Sprintf("oauth-%s", pending.clientID[:8]) + tok, rawToken, rawRefresh, err := registry.CreateWithRefresh( + label, []string{"*"}, "oauth", accessTokenTTL, refreshTokenTTL, + ) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"}) + return + } + + expiresIn := 0 + if tok.ExpiresAt != nil { + expiresIn = int(time.Until(*tok.ExpiresAt).Seconds()) + } + + writeJSON(w, http.StatusOK, map[string]any{ + "access_token": rawToken, + "token_type": "Bearer", + "expires_in": expiresIn, + "refresh_token": rawRefresh, + }) - // Mint a fresh scoped token instead of returning the legacy bearer token. - // This ensures every OAuth-issued token is independently revocable and - // auditable — the global legacy token is never exposed to OAuth clients. - label := fmt.Sprintf("oauth-%s", pending.clientID[:8]) - tok, rawToken, err := registry.Create(label, []string{"*"}, "oauth", 24*time.Hour) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "server_error"}) - return - } + case "refresh_token": + rawRefresh := r.FormValue("refresh_token") + if rawRefresh == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_request", "error_description": "refresh_token is required"}) + return + } + + newTok, rawAccess, rawRefresh, err := registry.RotateViaRefreshToken(rawRefresh) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid_grant", "error_description": "invalid or expired refresh token"}) + return + } + + expiresIn := 0 + if newTok.ExpiresAt != nil { + expiresIn = int(time.Until(*newTok.ExpiresAt).Seconds()) + } + + writeJSON(w, http.StatusOK, map[string]any{ + "access_token": rawAccess, + "token_type": "Bearer", + "expires_in": expiresIn, + "refresh_token": rawRefresh, + }) - expiresIn := 0 - if tok.ExpiresAt != nil { - expiresIn = int(time.Until(*tok.ExpiresAt).Seconds()) + default: + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "unsupported_grant_type"}) } - - writeJSON(w, http.StatusOK, map[string]any{ - "access_token": rawToken, - "token_type": "Bearer", - "expires_in": expiresIn, - }) } } diff --git a/internal/mcp/serverbootstrap/oauth_refresh_test.go b/internal/mcp/serverbootstrap/oauth_refresh_test.go new file mode 100644 index 0000000..9e592c9 --- /dev/null +++ b/internal/mcp/serverbootstrap/oauth_refresh_test.go @@ -0,0 +1,335 @@ +package serverbootstrap + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/danieljustus/OpenPass/internal/mcp" +) + +type mockResponseWriter struct { + header http.Header + body bytes.Buffer + status int +} + +func (m *mockResponseWriter) Header() http.Header { return m.header } +func (m *mockResponseWriter) Write(b []byte) (int, error) { return m.body.Write(b) } +func (m *mockResponseWriter) WriteHeader(s int) { m.status = s } + +func challengeForVerifier(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +func TestOAuthRefreshToken_FullFlow(t *testing.T) { + dir := t.TempDir() + regPath := dir + "/tokens.json" + reg := mcp.NewTokenRegistry(regPath) + + accessTTL := 10 * time.Minute + refreshTTL := 30 * time.Minute + + store := newOAuthCodeStore() + handler := handleOAuthToken(store, reg, accessTTL, refreshTTL) + + code := "test-auth-code-123" + store.put(code, &pendingCode{ + clientID: "test-client", + redirectURI: "http://localhost:9999/callback", + codeChallenge: challengeForVerifier("test-verifier"), + challengeMethod: "S256", + expiresAt: time.Now().Add(5 * time.Minute), + }) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "code_verifier": {"test-verifier"}, + } + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := &mockResponseWriter{header: http.Header{}} + + handler.ServeHTTP(w, req) + + if w.status != http.StatusOK { + t.Fatalf("token exchange status = %d, want %d; body=%s", w.status, http.StatusOK, w.body.String()) + } + + var tokenResp map[string]any + if err := json.Unmarshal(w.body.Bytes(), &tokenResp); err != nil { + t.Fatalf("decode token response: %v", err) + } + if tokenResp["access_token"] == nil { + t.Fatal("access_token missing") + } + if tokenResp["refresh_token"] == nil { + t.Fatal("refresh_token missing") + } + + rawAccess := tokenResp["access_token"].(string) + rawRefresh := tokenResp["refresh_token"].(string) + + w2 := &mockResponseWriter{header: http.Header{}} + form2 := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {rawRefresh}, + } + req2, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form2.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + handler.ServeHTTP(w2, req2) + + if w2.status != http.StatusOK { + t.Fatalf("refresh token exchange status = %d, want %d; body=%s", w2.status, http.StatusOK, w2.body.String()) + } + + var refreshResp map[string]any + if err := json.Unmarshal(w2.body.Bytes(), &refreshResp); err != nil { + t.Fatalf("decode refresh response: %v", err) + } + if refreshResp["access_token"] == nil { + t.Fatal("new access_token missing after refresh") + } + if refreshResp["refresh_token"] == nil { + t.Fatal("new refresh_token missing after refresh") + } + + newAccess := refreshResp["access_token"].(string) + newRefresh := refreshResp["refresh_token"].(string) + + if newAccess == rawAccess { + t.Error("access token was not rotated") + } + if newRefresh == rawRefresh { + t.Error("refresh token was not rotated") + } + + oldHash := sha256HexRaw(rawAccess) + _, ok := reg.Get(oldHash) + if ok { + t.Error("old access token should be revoked after refresh") + } + + w3 := &mockResponseWriter{header: http.Header{}} + form3 := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {rawRefresh}, + } + req3, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form3.Encode())) + req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") + handler.ServeHTTP(w3, req3) + if w3.status != http.StatusBadRequest { + t.Errorf("old refresh token status = %d, want %d", w3.status, http.StatusBadRequest) + } + + var errResp map[string]any + if err := json.Unmarshal(w3.body.Bytes(), &errResp); err != nil { + t.Fatalf("decode error response: %v", err) + } + if errResp["error"] != "invalid_grant" { + t.Errorf("error = %q, want invalid_grant", errResp["error"]) + } + + w4 := &mockResponseWriter{header: http.Header{}} + form4 := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {newRefresh}, + } + req4, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form4.Encode())) + req4.Header.Set("Content-Type", "application/x-www-form-urlencoded") + handler.ServeHTTP(w4, req4) + if w4.status != http.StatusOK { + t.Fatalf("new refresh token status = %d, want %d", w4.status, http.StatusOK) + } +} + +func TestOAuthRefreshToken_ExpiredRefreshDenied(t *testing.T) { + dir := t.TempDir() + regPath := dir + "/tokens.json" + reg := mcp.NewTokenRegistry(regPath) + + accessTTL := 1 * time.Hour + refreshTTL := 1 * time.Millisecond + + store := newOAuthCodeStore() + handler := handleOAuthToken(store, reg, accessTTL, refreshTTL) + + code := "test-code-expired" + store.put(code, &pendingCode{ + clientID: "test-client", + redirectURI: "http://localhost:9999/callback", + codeChallenge: challengeForVerifier("v"), + challengeMethod: "S256", + expiresAt: time.Now().Add(5 * time.Minute), + }) + + form := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "code_verifier": {"v"}, + } + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := &mockResponseWriter{header: http.Header{}} + handler.ServeHTTP(w, req) + + if w.status != http.StatusOK { + t.Fatalf("token exchange status = %d, want %d", w.status, http.StatusOK) + } + + var tokenResp map[string]any + if err := json.Unmarshal(w.body.Bytes(), &tokenResp); err != nil { + t.Fatalf("decode token response: %v", err) + } + rawRefresh := tokenResp["refresh_token"].(string) + + time.Sleep(2 * time.Millisecond) + + w2 := &mockResponseWriter{header: http.Header{}} + form2 := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {rawRefresh}, + } + req2, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form2.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + handler.ServeHTTP(w2, req2) + + if w2.status != http.StatusBadRequest { + t.Fatalf("expired refresh token status = %d, want %d; body=%s", w2.status, http.StatusBadRequest, w2.body.String()) + } + + var errResp map[string]any + if err := json.Unmarshal(w2.body.Bytes(), &errResp); err != nil { + t.Fatalf("decode error response: %v", err) + } + if errResp["error"] != "invalid_grant" { + t.Errorf("error = %q, want invalid_grant", errResp["error"]) + } +} + +func TestOAuthRefreshToken_RegisterResponseIncludesRefresh(t *testing.T) { + clientStore := newOAuthClientStore() + handler := handleOAuthRegister(clientStore) + + reqBody := `{"redirect_uris": ["http://localhost:3000/callback"]}` + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + w := &mockResponseWriter{header: http.Header{}} + + handler.ServeHTTP(w, req) + + if w.status != http.StatusCreated { + t.Fatalf("register status = %d, want %d", w.status, http.StatusCreated) + } + + var regResp map[string]any + if err := json.Unmarshal(w.body.Bytes(), ®Resp); err != nil { + t.Fatalf("decode register response: %v", err) + } + + grantTypes, ok := regResp["grant_types"].([]any) + if !ok { + t.Fatal("grant_types missing from registration response") + } + + hasRefresh := false + hasAuthCode := false + for _, gt := range grantTypes { + switch gt { + case "authorization_code": + hasAuthCode = true + case "refresh_token": + hasRefresh = true + } + } + if !hasAuthCode { + t.Error("authorization_code missing from grant_types") + } + if !hasRefresh { + t.Error("refresh_token missing from grant_types") + } +} + +func TestOAuthRefreshToken_WellKnownIncludesRefresh(t *testing.T) { + handler := handleOAuthAuthorizationServer("127.0.0.1", 9999) + + req, _ := http.NewRequest(http.MethodGet, "/", nil) + w := &mockResponseWriter{header: http.Header{}} + + handler.ServeHTTP(w, req) + + var body map[string]any + if err := json.Unmarshal(w.body.Bytes(), &body); err != nil { + t.Fatalf("decode well-known response: %v", err) + } + + grantTypes, ok := body["grant_types_supported"].([]any) + if !ok { + t.Fatal("grant_types_supported missing") + } + + hasRefresh := false + for _, gt := range grantTypes { + if gt == "refresh_token" { + hasRefresh = true + break + } + } + if !hasRefresh { + t.Error("refresh_token missing from grant_types_supported") + } +} + +func TestOAuthRefreshToken_UnsupportedGrantType(t *testing.T) { + reg := mcp.NewTokenRegistry("") + handler := handleOAuthToken(nil, reg, 24*time.Hour, 720*time.Hour) + + form := url.Values{"grant_type": {"unsupported"}} + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := &mockResponseWriter{header: http.Header{}} + + handler.ServeHTTP(w, req) + + if w.status != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", w.status, http.StatusBadRequest) + } + + var errResp map[string]any + if err := json.Unmarshal(w.body.Bytes(), &errResp); err != nil { + t.Fatalf("decode error: %v", err) + } + if errResp["error"] != "unsupported_grant_type" { + t.Errorf("error = %q, want unsupported_grant_type", errResp["error"]) + } +} + +func TestOAuthRefreshToken_MissingRefreshToken(t *testing.T) { + reg := mcp.NewTokenRegistry("") + handler := handleOAuthToken(nil, reg, 24*time.Hour, 720*time.Hour) + + form := url.Values{"grant_type": {"refresh_token"}} + req, _ := http.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := &mockResponseWriter{header: http.Header{}} + + handler.ServeHTTP(w, req) + + if w.status != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", w.status, http.StatusBadRequest) + } +} + +func sha256HexRaw(s string) string { + h := sha256.Sum256([]byte(s)) + return hex.EncodeToString(h[:]) +} diff --git a/internal/mcp/serverbootstrap/wellknown.go b/internal/mcp/serverbootstrap/wellknown.go index 72cc823..f28781f 100644 --- a/internal/mcp/serverbootstrap/wellknown.go +++ b/internal/mcp/serverbootstrap/wellknown.go @@ -46,7 +46,7 @@ func handleOAuthAuthorizationServer(bind string, port int) http.HandlerFunc { "response_types_supported": []string{"code"}, "code_challenge_methods_supported": []string{"S256"}, "token_endpoint_auth_methods_supported": []string{"none"}, - "grant_types_supported": []string{"authorization_code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, }) } } diff --git a/internal/mcp/token.go b/internal/mcp/token.go index 77812cb..639405c 100644 --- a/internal/mcp/token.go +++ b/internal/mcp/token.go @@ -27,33 +27,37 @@ type TokenRegistryFile struct { // TokenRegistryEntry is a single entry in the on-disk token registry. type TokenRegistryEntry struct { - ID string `json:"id"` - Label string `json:"label,omitempty"` - Hash string `json:"hash"` - Prefix string `json:"prefix"` - AllowedTools []string `json:"allowed_tools"` - AgentName string `json:"agent_name,omitempty"` - CreatedAt time.Time `json:"created_at"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` - LastUsedAt *time.Time `json:"last_used_at,omitempty"` - Revoked bool `json:"revoked"` - RevokedAt *time.Time `json:"revoked_at,omitempty"` + ID string `json:"id"` + Label string `json:"label,omitempty"` + Hash string `json:"hash"` + Prefix string `json:"prefix"` + AllowedTools []string `json:"allowed_tools"` + AgentName string `json:"agent_name,omitempty"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + Revoked bool `json:"revoked"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + RefreshTokenHash string `json:"refresh_token_hash,omitempty"` + RefreshExpiresAt *time.Time `json:"refresh_expires_at,omitempty"` } // ScopedToken is the in-memory representation of a scoped token with its // associated metadata. It is safe for concurrent access. type ScopedToken struct { - ID string `json:"id"` - Label string `json:"label,omitempty"` - Hash string `json:"hash"` - Prefix string `json:"prefix"` - AllowedTools []string `json:"allowed_tools"` - AgentName string `json:"agent_name,omitempty"` - CreatedAt time.Time `json:"created_at"` - ExpiresAt *time.Time `json:"expires_at,omitempty"` - LastUsedAt *time.Time `json:"last_used_at,omitempty"` - Revoked bool `json:"revoked"` - RevokedAt *time.Time `json:"revoked_at,omitempty"` + ID string `json:"id"` + Label string `json:"label,omitempty"` + Hash string `json:"hash"` + Prefix string `json:"prefix"` + AllowedTools []string `json:"allowed_tools"` + AgentName string `json:"agent_name,omitempty"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + Revoked bool `json:"revoked"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + RefreshTokenHash string `json:"refresh_token_hash,omitempty"` + RefreshExpiresAt *time.Time `json:"refresh_expires_at,omitempty"` mu sync.Mutex } @@ -71,6 +75,18 @@ func (t *ScopedToken) IsExpired() bool { return time.Now().After(*t.ExpiresAt) } +// IsRefreshExpired returns true if the refresh token has a defined expiration +// that has already passed. A nil RefreshExpiresAt means no expiration. +func (t *ScopedToken) IsRefreshExpired() bool { + if t == nil { + return true + } + if t.RefreshExpiresAt == nil { + return false + } + return time.Now().After(*t.RefreshExpiresAt) +} + // IsToolAllowed returns true when the given tool name is permitted by this // token. A wildcard "*" in the AllowedTools list grants access to every tool. func (t *ScopedToken) IsToolAllowed(toolName string) bool { @@ -106,17 +122,19 @@ func (t *ScopedToken) toEntry() TokenRegistryEntry { return TokenRegistryEntry{} } return TokenRegistryEntry{ - ID: t.ID, - Label: t.Label, - Hash: t.Hash, - Prefix: t.Prefix, - AllowedTools: t.AllowedTools, - AgentName: t.AgentName, - CreatedAt: t.CreatedAt, - ExpiresAt: t.ExpiresAt, - LastUsedAt: t.LastUsedAt, - Revoked: t.Revoked, - RevokedAt: t.RevokedAt, + ID: t.ID, + Label: t.Label, + Hash: t.Hash, + Prefix: t.Prefix, + AllowedTools: t.AllowedTools, + AgentName: t.AgentName, + CreatedAt: t.CreatedAt, + ExpiresAt: t.ExpiresAt, + LastUsedAt: t.LastUsedAt, + Revoked: t.Revoked, + RevokedAt: t.RevokedAt, + RefreshTokenHash: t.RefreshTokenHash, + RefreshExpiresAt: t.RefreshExpiresAt, } } @@ -127,17 +145,19 @@ func entryToScopedToken(e TokenRegistryEntry) *ScopedToken { allowed = []string{} } return &ScopedToken{ - ID: e.ID, - Label: e.Label, - Hash: e.Hash, - Prefix: e.Prefix, - AllowedTools: allowed, - AgentName: e.AgentName, - CreatedAt: e.CreatedAt, - ExpiresAt: e.ExpiresAt, - LastUsedAt: e.LastUsedAt, - Revoked: e.Revoked, - RevokedAt: e.RevokedAt, + ID: e.ID, + Label: e.Label, + Hash: e.Hash, + Prefix: e.Prefix, + AllowedTools: allowed, + AgentName: e.AgentName, + CreatedAt: e.CreatedAt, + ExpiresAt: e.ExpiresAt, + LastUsedAt: e.LastUsedAt, + Revoked: e.Revoked, + RevokedAt: e.RevokedAt, + RefreshTokenHash: e.RefreshTokenHash, + RefreshExpiresAt: e.RefreshExpiresAt, } } @@ -193,7 +213,7 @@ func (r *TokenRegistry) Load() error { func (r *TokenRegistry) Save() error { r.mu.Lock() file := TokenRegistryFile{ - Version: 1, + Version: 2, Tokens: make(map[string]TokenRegistryEntry, len(r.entries)), } for _, t := range r.entries { @@ -302,6 +322,135 @@ func (r *TokenRegistry) Revoke(id string) bool { return false } +// getByRefreshTokenHash looks up a token by its refresh token hash. Returns +// nil if not found, revoked, or expired. +func (r *TokenRegistry) getByRefreshTokenHash(refreshHash string) *ScopedToken { + for _, t := range r.entries { + if t.RefreshTokenHash == refreshHash && !t.Revoked && !t.IsExpired() { + return t + } + } + return nil +} + +// CreateWithRefresh generates a new access+refresh token pair, stores both +// hashes in the registry, and persists to disk. It returns the token metadata, +// the cleartext access token, and the cleartext refresh token. +func (r *TokenRegistry) CreateWithRefresh(label string, allowedTools []string, agentName string, accessTTL, refreshTTL time.Duration) (*ScopedToken, string, string, error) { + accessBuf := make([]byte, 32) + if _, err := randReader.Read(accessBuf); err != nil { + return nil, "", "", fmt.Errorf("generate access token: %w", err) + } + rawAccess := hex.EncodeToString(accessBuf) + + refreshBuf := make([]byte, 32) + if _, err := randReader.Read(refreshBuf); err != nil { + return nil, "", "", fmt.Errorf("generate refresh token: %w", err) + } + rawRefresh := hex.EncodeToString(refreshBuf) + + id := generateTokenID() + accessHash := sha256Hex(rawAccess) + refreshHash := sha256Hex(rawRefresh) + prefix := rawAccess[:4] + + var expiresAt *time.Time + if accessTTL > 0 { + t := time.Now().UTC().Add(accessTTL) + expiresAt = &t + } + var refreshExpiresAt *time.Time + if refreshTTL > 0 { + t := time.Now().UTC().Add(refreshTTL) + refreshExpiresAt = &t + } + + if allowedTools == nil { + allowedTools = []string{} + } + + createdAt := time.Now().UTC() + t := &ScopedToken{ + ID: id, + Label: label, + Hash: accessHash, + Prefix: prefix, + AllowedTools: allowedTools, + AgentName: agentName, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + RefreshTokenHash: refreshHash, + RefreshExpiresAt: refreshExpiresAt, + } + + r.mu.Lock() + r.entries[accessHash] = t + r.mu.Unlock() + + if err := r.Save(); err != nil { + r.mu.Lock() + delete(r.entries, accessHash) + r.mu.Unlock() + return nil, "", "", err + } + + return t, rawAccess, rawRefresh, nil +} + +// RotateViaRefreshToken revokes the old access token associated with the given +// refresh token and creates a new access+refresh token pair. The old refresh +// token is also invalidated (single-use pattern). Returns the new token +// metadata, raw access token, and raw refresh token. +func (r *TokenRegistry) RotateViaRefreshToken(rawRefreshToken string) (*ScopedToken, string, string, error) { + refreshHash := sha256Hex(rawRefreshToken) + + r.mu.Lock() + + oldEntry := r.getByRefreshTokenHash(refreshHash) + if oldEntry == nil { + r.mu.Unlock() + return nil, "", "", fmt.Errorf("invalid refresh token") + } + + if oldEntry.IsRefreshExpired() { + r.mu.Unlock() + return nil, "", "", fmt.Errorf("invalid refresh token: expired") + } + + oldEntry.Revoked = true + now := time.Now().UTC() + oldEntry.RevokedAt = &now + + r.mu.Unlock() + + var accessTTL, refreshTTL time.Duration + if oldEntry.ExpiresAt != nil { + accessTTL = time.Until(*oldEntry.ExpiresAt) + if accessTTL < 0 { + accessTTL = 0 + } + } + if oldEntry.RefreshExpiresAt != nil { + refreshTTL = time.Until(*oldEntry.RefreshExpiresAt) + if refreshTTL < 0 { + refreshTTL = 0 + } + } + + newTok, rawAccess, rawRefresh, err := r.CreateWithRefresh( + oldEntry.Label, + oldEntry.AllowedTools, + oldEntry.AgentName, + accessTTL, + refreshTTL, + ) + if err != nil { + return nil, "", "", fmt.Errorf("rotate via refresh: %w", err) + } + + return newTok, rawAccess, rawRefresh, nil +} + // List returns a snapshot of all tokens currently in the registry. Expired // tokens are excluded and removed; revoked tokens are included for the audit // trail. diff --git a/internal/mcp/token_test.go b/internal/mcp/token_test.go index b4d2f54..a0a770f 100644 --- a/internal/mcp/token_test.go +++ b/internal/mcp/token_test.go @@ -278,8 +278,8 @@ func TestTokenRegistry_CreateLoadSave_Roundtrip(t *testing.T) { if err := json.Unmarshal(data, &file); err != nil { t.Fatalf("unmarshal registry: %v", err) } - if file.Version != 1 { - t.Errorf("version = %d, want 1", file.Version) + if file.Version < 1 || file.Version > 2 { + t.Errorf("version = %d, want 1 or 2", file.Version) } if len(file.Tokens) != 1 { t.Fatalf("token count = %d, want 1", len(file.Tokens)) @@ -809,8 +809,8 @@ func TestTokenRegistry_Load_Tokens(t *testing.T) { if err := json.Unmarshal(data, &file); err != nil { t.Fatalf("unmarshal: %v", err) } - if file.Version != 1 { - t.Errorf("Version = %d, want 1", file.Version) + if file.Version < 1 || file.Version > 2 { + t.Errorf("Version = %d, want 1 or 2", file.Version) } }