From f616724e847422e154b1007c203f10c4126a31d4 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 May 2026 11:00:20 +0200 Subject: [PATCH] feat: persistent OAuth client registry (fix disconnect-on-restart) - oauthClientStore: add Load()/Save() methods with JSON persistence - Factory loadOAuthClientStore(vaultDir) in http.go - put() calls Save() best-effort; errors logged, never block registration - Error signals: invalid_client + WWW-Authenticate header - Optional TTL with cleanupExpired() and background sweeper - Full test coverage: load/save roundtrip, read-only dir resilience, restart persistence, expired client cleanup, file format verification Fixes #45 --- internal/mcp/serverbootstrap/http.go | 6 +- internal/mcp/serverbootstrap/oauth.go | 169 ++++++++++++- .../serverbootstrap/oauth_persistence_test.go | 231 ++++++++++++++++++ .../serverbootstrap/serverbootstrap_test.go | 85 +++++++ 4 files changed, 483 insertions(+), 8 deletions(-) create mode 100644 internal/mcp/serverbootstrap/oauth_persistence_test.go diff --git a/internal/mcp/serverbootstrap/http.go b/internal/mcp/serverbootstrap/http.go index 4e3c26e..90cc3be 100644 --- a/internal/mcp/serverbootstrap/http.go +++ b/internal/mcp/serverbootstrap/http.go @@ -181,7 +181,11 @@ func RunHTTPServerOnListener(ctx context.Context, listener net.Listener, v *vaul // BearerAuthMiddleware — clients don't have a token yet at this point. // User consent is required at the authorize step (see handleOAuthAuthorize). oauthStore := newOAuthCodeStore() - clientStore := newOAuthClientStore() + clientStore, err := loadOAuthClientStore(vaultDir) + if err != nil { + return fmt.Errorf("load oauth client store: %w", err) + } + clientStore.StartCleanup(ctx, 5*time.Minute) oauthRegisterHandler := mcp.OriginValidationMiddleware(addr, handleOAuthRegister(clientStore)) mux.HandleFunc("POST /oauth/register", oauthRegisterHandler.ServeHTTP) diff --git a/internal/mcp/serverbootstrap/oauth.go b/internal/mcp/serverbootstrap/oauth.go index b47fd8c..58dab05 100644 --- a/internal/mcp/serverbootstrap/oauth.go +++ b/internal/mcp/serverbootstrap/oauth.go @@ -1,6 +1,7 @@ package serverbootstrap import ( + "context" "crypto/rand" "crypto/sha256" "crypto/subtle" @@ -10,40 +11,193 @@ import ( "fmt" "net/http" "net/url" + "os" + "path/filepath" "strings" "sync" "time" + "github.com/danieljustus/OpenPass/internal/fileutil" "github.com/danieljustus/OpenPass/internal/mcp" ) -// oauthClientStore persists registered OAuth client applications in memory. +const ( + oauthClientsFileVersion = 1 + oauthClientsFileName = "mcp-oauth-clients.json" +) + +// oauthClientStoreFile is the on-disk JSON representation of the client store. +type oauthClientStoreFile struct { + Version int `json:"version"` + Clients map[string]*registeredClient `json:"clients"` +} + +// oauthClientStore persists registered OAuth client applications. It is backed +// by an on-disk JSON file when a vaultDir is provided; otherwise it operates +// purely in memory. type oauthClientStore struct { mu sync.Mutex clients map[string]*registeredClient + path string // path to the JSON persistence file, empty = in-memory only } type registeredClient struct { - ClientID string `json:"client_id"` - RedirectURIs []string `json:"redirect_uris"` - CreatedAt time.Time `json:"created_at"` + ClientID string `json:"client_id"` + RedirectURIs []string `json:"redirect_uris"` + CreatedAt time.Time `json:"created_at"` + TTL *int64 `json:"ttl_seconds,omitempty"` // optional TTL in seconds + ExpiresAt *time.Time `json:"expires_at,omitempty"` // computed expiration time } +// newOAuthClientStore creates an in-memory-only client store. func newOAuthClientStore() *oauthClientStore { return &oauthClientStore{clients: make(map[string]*registeredClient)} } -func (s *oauthClientStore) put(c *registeredClient) { +// loadOAuthClientStore creates a client store backed by a persistent JSON file +// at /. If vaultDir is empty, the store is +// purely in-memory. The file is loaded on creation; a missing file is not an +// error (empty store). +func loadOAuthClientStore(vaultDir string) (*oauthClientStore, error) { + s := &oauthClientStore{clients: make(map[string]*registeredClient)} + if vaultDir == "" { + return s, nil + } + s.path = filepath.Join(vaultDir, oauthClientsFileName) + if err := s.Load(); err != nil { + return nil, err + } + return s, nil +} + +// Load reads the JSON client registry file from disk and populates the +// in-memory entries. If the file does not exist it is a no-op (empty store). +func (s *oauthClientStore) Load() error { s.mu.Lock() defer s.mu.Unlock() + + if s.path == "" { + return nil + } + + data, err := os.ReadFile(s.path) //#nosec G304 -- path is set from vaultDir in loadOAuthClientStore + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read oauth client store: %w", err) + } + + var file oauthClientStoreFile + if err := json.Unmarshal(data, &file); err != nil { + return fmt.Errorf("parse oauth client store: %w", err) + } + + s.clients = make(map[string]*registeredClient, len(file.Clients)) + for id, c := range file.Clients { + if c != nil { + s.clients[id] = c + } + } + return nil +} + +// Save persists the current in-memory client entries to the JSON file with +// 0o600 permissions. A no-op if the store has no associated file path. +func (s *oauthClientStore) Save() error { + s.mu.Lock() + file := oauthClientStoreFile{ + Version: oauthClientsFileVersion, + Clients: make(map[string]*registeredClient, len(s.clients)), + } + for id, c := range s.clients { + file.Clients[id] = c + } + s.mu.Unlock() + + if s.path == "" { + return nil + } + + data, err := json.MarshalIndent(file, "", " ") + if err != nil { + return fmt.Errorf("marshal oauth client store: %w", err) + } + + if err := fileutil.AtomicWriteFile(s.path, append(data, '\n'), 0o600); err != nil { + return fmt.Errorf("write oauth client store: %w", err) + } + return nil +} + +func (s *oauthClientStore) put(c *registeredClient) { + s.mu.Lock() s.clients[c.ClientID] = c + s.mu.Unlock() + + // Best-effort persistence: log error but never fail the registration. + if s.path != "" { + if err := s.Save(); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to persist OAuth client store: %v\n", err) + } + } } func (s *oauthClientStore) get(clientID string) (*registeredClient, bool) { s.mu.Lock() defer s.mu.Unlock() c, ok := s.clients[clientID] - return c, ok + if !ok { + return nil, false + } + // Lazy expiry check: if the client has an expiration and it's passed, + // treat it as not found. + if c.ExpiresAt != nil && time.Now().After(*c.ExpiresAt) { + delete(s.clients, clientID) + return nil, false + } + return c, true +} + +// cleanupExpired removes all clients whose TTL has expired. Returns the +// count of removed entries. +func (s *oauthClientStore) cleanupExpired() int { + s.mu.Lock() + now := time.Now() + var removed int + for id, c := range s.clients { + if c.ExpiresAt != nil && now.After(*c.ExpiresAt) { + delete(s.clients, id) + removed++ + } + } + needsSave := removed > 0 && s.path != "" + s.mu.Unlock() + + if needsSave { + if err := s.Save(); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to persist OAuth client store after cleanup: %v\n", err) + } + } + return removed +} + +// StartCleanup launches a background goroutine that periodically sweeps +// expired client entries at the given interval. It returns a stop function. +func (s *oauthClientStore) StartCleanup(ctx context.Context, interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.cleanupExpired() + case <-ctx.Done(): + s.cleanupExpired() + return + } + } + }() } type oauthCodeStore struct { @@ -163,8 +317,9 @@ func handleOAuthAuthorize(store *oauthCodeStore, clientStore *oauthClientStore) // Validate client_id against registered clients. client, ok := clientStore.get(clientID) if !ok { + w.Header().Set("WWW-Authenticate", `Bearer realm="openpass",error="invalid_client",error_description="unknown client_id; register via POST /oauth/register first"`) writeJSON(w, http.StatusBadRequest, map[string]string{ - "error": "unauthorized_client", + "error": "invalid_client", "error_description": "unknown client_id; register via POST /oauth/register first", }) return diff --git a/internal/mcp/serverbootstrap/oauth_persistence_test.go b/internal/mcp/serverbootstrap/oauth_persistence_test.go new file mode 100644 index 0000000..276fa9d --- /dev/null +++ b/internal/mcp/serverbootstrap/oauth_persistence_test.go @@ -0,0 +1,231 @@ +package serverbootstrap + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestOAuthClientStore_LoadSaveRoundtrip(t *testing.T) { + dir := t.TempDir() + store, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore: %v", err) + } + + client := ®isteredClient{ + ClientID: "test-client-1", + RedirectURIs: []string{"http://localhost:3000/callback"}, + CreatedAt: time.Now(), + } + store.put(client) + + filePath := filepath.Join(dir, oauthClientsFileName) + info, err := os.Stat(filePath) + if err != nil { + t.Fatalf("client store file not created: %v", err) + } + if info.Mode().Perm() != 0o600 { + t.Errorf("file permissions = %o, want 0o600", info.Mode().Perm()) + } + + store2, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore (2nd): %v", err) + } + + got, ok := store2.get("test-client-1") + if !ok { + t.Fatal("fresh store: client not found after Load()") + } + if got.ClientID != "test-client-1" { + t.Errorf("client_id = %q, want %q", got.ClientID, "test-client-1") + } + if len(got.RedirectURIs) != 1 || got.RedirectURIs[0] != "http://localhost:3000/callback" { + t.Errorf("redirect_uris = %v, want [http://localhost:3000/callback]", got.RedirectURIs) + } +} + +func TestOAuthClientStore_SaveDoesNotBlockOnError(t *testing.T) { + dir := t.TempDir() + readOnlyDir := filepath.Join(dir, "readonly") + if err := os.Mkdir(readOnlyDir, 0o555); err != nil { + t.Fatalf("Mkdir: %v", err) + } + + roStore, err := loadOAuthClientStore(readOnlyDir) + if err != nil { + t.Fatalf("loadOAuthClientStore readonly: %v", err) + } + + roStore.put(®isteredClient{ + ClientID: "should-not-fail", + RedirectURIs: []string{"http://localhost:9999/callback"}, + CreatedAt: time.Now(), + }) + + _, ok := roStore.get("should-not-fail") + if !ok { + t.Fatal("client should be in memory even if save failed") + } +} + +func TestOAuthClientStore_MissingFileIsNoOp(t *testing.T) { + dir := t.TempDir() + store, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore on empty dir: %v", err) + } + if store == nil { + t.Fatal("store is nil") + } + _, ok := store.get("anything") + if ok { + t.Fatal("expected client not found in empty store") + } +} + +func TestOAuthClientStore_InMemoryNoFileCreated(t *testing.T) { + store := newOAuthClientStore() + store.put(®isteredClient{ + ClientID: "mem-only", + RedirectURIs: []string{"http://localhost:3000/callback"}, + CreatedAt: time.Now(), + }) + _, ok := store.get("mem-only") + if !ok { + t.Fatal("in-memory store: client not found") + } +} + +func TestOAuthClientStore_SaveWithoutPathIsNoOp(t *testing.T) { + store := newOAuthClientStore() + if err := store.Save(); err != nil { + t.Fatalf("Save() on in-memory store should be no-op: %v", err) + } +} + +func TestOAuthClientStore_FileFormat(t *testing.T) { + dir := t.TempDir() + store, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore: %v", err) + } + + store.put(®isteredClient{ + ClientID: "fmt-check", + RedirectURIs: []string{"http://localhost:3000/callback"}, + CreatedAt: time.Now(), + }) + + filePath := filepath.Join(dir, oauthClientsFileName) + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + + var file struct { + Version int `json:"version"` + Clients map[string]*registeredClient `json:"clients"` + } + if err := json.Unmarshal(data, &file); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if file.Version != 1 { + t.Errorf("version = %d, want 1", file.Version) + } + if file.Clients == nil { + t.Fatal("clients field is nil") + } + if _, ok := file.Clients["fmt-check"]; !ok { + t.Error("fmt-check client missing from file") + } + if !strings.HasSuffix(string(data), "\n") { + t.Error("file should end with newline") + } +} + +func TestOAuthClientStore_LoadWithInvalidJSON(t *testing.T) { + dir := t.TempDir() + filePath := filepath.Join(dir, oauthClientsFileName) + if err := os.WriteFile(filePath, []byte("{bad json}"), 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + _, err := loadOAuthClientStore(dir) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "parse oauth client store") { + t.Errorf("error = %q, want contains 'parse oauth client store'", err.Error()) + } +} + +func TestOAuthClientStore_CleanupExpired(t *testing.T) { + dir := t.TempDir() + store, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore: %v", err) + } + + past := time.Now().Add(-24 * time.Hour) + + store.put(®isteredClient{ + ClientID: "valid", + CreatedAt: time.Now(), + }) + store.put(®isteredClient{ + ClientID: "expired", + CreatedAt: past, + ExpiresAt: &past, + }) + + removed := store.cleanupExpired() + if removed != 1 { + t.Errorf("cleanupExpired removed %d, want 1", removed) + } + _, ok := store.get("valid") + if !ok { + t.Error("valid client was removed") + } + _, ok = store.get("expired") + if ok { + t.Error("expired client should not be found") + } + + store2, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore after cleanup: %v", err) + } + _, ok = store2.get("expired") + if ok { + t.Error("expired client persisted after cleanup") + } + _, ok = store2.get("valid") + if !ok { + t.Error("valid client missing after reload") + } +} + +func TestOAuthClientStore_ExpiredClientNotReturnedByGet(t *testing.T) { + dir := t.TempDir() + store, err := loadOAuthClientStore(dir) + if err != nil { + t.Fatalf("loadOAuthClientStore: %v", err) + } + + past := time.Now().Add(-1 * time.Hour) + store.put(®isteredClient{ + ClientID: "old-and-cold", + CreatedAt: past, + ExpiresAt: &past, + }) + + _, ok := store.get("old-and-cold") + if ok { + t.Error("expired client should not be returned by get()") + } +} diff --git a/internal/mcp/serverbootstrap/serverbootstrap_test.go b/internal/mcp/serverbootstrap/serverbootstrap_test.go index 0445b20..0061cf4 100644 --- a/internal/mcp/serverbootstrap/serverbootstrap_test.go +++ b/internal/mcp/serverbootstrap/serverbootstrap_test.go @@ -1017,3 +1017,88 @@ func TestRunHTTPServer_InitializeAndToolsList(t *testing.T) { t.Errorf("tools/list status = %d, want %d", resp2.StatusCode, http.StatusOK) } } + +func TestRunHTTPServer_OAuthClientPersistenceAcrossRestart(t *testing.T) { + v := newTestVault(t) + port := reserveFreePort(t) + baseURL := fmt.Sprintf("http://127.0.0.1:%d", port) + + ctx1, cancel1 := context.WithCancel(context.Background()) + wait1 := runHTTPServerAsync(ctx1, t, "127.0.0.1", port, v, mcp.New) + + client := newTestHTTPClient() + reqBody := `{"redirect_uris": ["http://127.0.0.1:9999/callback"]}` + req, _ := http.NewRequest(http.MethodPost, baseURL+"/oauth/register", strings.NewReader(reqBody)) + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("register request: %v", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusCreated { + t.Fatalf("register status = %d, want %d", resp.StatusCode, http.StatusCreated) + } + var regResp map[string]any + if err := json.NewDecoder(resp.Body).Decode(®Resp); err != nil { + t.Fatalf("decode register response: %v", err) + } + clientID, ok := regResp["client_id"].(string) + if !ok || clientID == "" { + t.Fatal("client_id missing from registration response") + } + + clientFilePath := filepath.Join(v.Dir, oauthClientsFileName) + if _, err := os.Stat(clientFilePath); err != nil { + t.Fatalf("client store file missing: %v", err) + } + + cancel1() + wait1() + + v2 := &vaultpkg.Vault{ + Dir: v.Dir, + Config: config.Default(), + } + + listener, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err != nil { + t.Fatalf("re-listen: %v", err) + } + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + go func() { + _ = RunHTTPServerOnListener(ctx2, listener, v2, v2.Dir, "dev", mcp.New) + }() + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/health", port)) + if err == nil { + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + break + } + } + time.Sleep(25 * time.Millisecond) + } + + authURL := fmt.Sprintf("http://127.0.0.1:%d/mcp/oauth/authorize?response_type=code&client_id=%s&redirect_uri=%s&code_challenge=abc123&code_challenge_method=S256&state=test", + port, clientID, url.QueryEscape("http://127.0.0.1:9999/callback")) + authReq, _ := http.NewRequest(http.MethodGet, authURL, nil) + authReq.Header.Set("Origin", fmt.Sprintf("http://127.0.0.1:%d", port)) + resp2, err := client.Do(authReq) + if err != nil { + t.Fatalf("authorize request after restart: %v", err) + } + defer func() { _ = resp2.Body.Close() }() + // After restart the registered client should still be known. If the client + // were not persisted, the response would be 400 with "invalid_client". + // A 403 (user consent denied) or 302 (redirect with auth code) both confirm + // the client was found. + if resp2.StatusCode == http.StatusBadRequest { + var errBody map[string]any + _ = json.NewDecoder(resp2.Body).Decode(&errBody) + if errBody["error"] == "invalid_client" { + t.Fatalf("client was not persisted across restart: %v", errBody) + } + } +}