diff --git a/auth.go b/auth.go index 15aa829..9c8df4e 100644 --- a/auth.go +++ b/auth.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "os" + "time" retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/cli/tui" @@ -59,11 +60,41 @@ func authenticate( } // refreshAccessToken exchanges a refresh token for a new access token. +// +// It takes a cross-process advisory lock before the critical section so that +// concurrent CLI invocations cannot spend the same refresh token twice — +// which would cause one of them to receive invalid_grant on rotation servers. +// After acquiring the lock, it re-reads the store: if a peer process has +// already refreshed, the peer's fresh token is returned without a network +// call. func refreshAccessToken( ctx context.Context, cfg *AppConfig, - refreshToken string, + stale credstore.Token, ) (*credstore.Token, error) { + unlock, err := lockTokenStore(ctx, cfg.TokenFile, cfg.ClientID) + if err != nil { + return nil, fmt.Errorf("acquire refresh lock: %w", err) + } + defer unlock.Close() + + refreshToken := stale.RefreshToken + + // Peer check: another process may have refreshed while we waited for the lock. + if fresh, loadErr := cfg.Store.Load(cfg.ClientID); loadErr == nil { + peerRefreshed := fresh.AccessToken != "" && + time.Now().Before(fresh.ExpiresAt) && + (fresh.AccessToken != stale.AccessToken || + fresh.RefreshToken != stale.RefreshToken) + if peerRefreshed { + return &fresh, nil + } + // Peer may have rotated the refresh token without updating our view. + if fresh.RefreshToken != "" { + refreshToken = fresh.RefreshToken + } + } + ctx, cancel := context.WithTimeout(ctx, cfg.RefreshTokenTimeout) defer cancel() @@ -143,7 +174,7 @@ func makeAPICallWithAutoRefresh( if resp.StatusCode == http.StatusUnauthorized { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventAccessTokenRejected}) - newStorage, err := refreshAccessToken(ctx, cfg, storage.RefreshToken) + newStorage, err := refreshAccessToken(ctx, cfg, *storage) if err != nil { if errors.Is(err, ErrRefreshTokenExpired) { return ErrRefreshTokenExpired diff --git a/config.go b/config.go index 5d6f51f..b916593 100644 --- a/config.go +++ b/config.go @@ -80,6 +80,7 @@ type AppConfig struct { Scope string ForceDevice bool TokenStoreMode string // "auto", "file", or "keyring" + TokenFile string // path used for the file backend and for the cross-process refresh lock RetryClient *retry.Client Store credstore.Store[credstore.Token] @@ -154,7 +155,7 @@ func loadStoreConfig() *AppConfig { cfg.ClientID = getConfig(flagClientID, "CLIENT_ID", "") cfg.TokenStoreMode = getConfig(flagTokenStore, "TOKEN_STORE", "auto") - tokenFile := getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json") + cfg.TokenFile = getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json") if cfg.ClientID == "" { fmt.Fprintln(os.Stderr, "Error: CLIENT_ID not set. Please provide it via:") @@ -166,7 +167,7 @@ func loadStoreConfig() *AppConfig { } var storeErr error - cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, tokenFile, defaultKeyringService) + cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, cfg.TokenFile, defaultKeyringService) if storeErr != nil { fmt.Fprintln(os.Stderr, storeErr) os.Exit(1) diff --git a/go.mod b/go.mod index 6af6b8a..3bd2bfa 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( charm.land/lipgloss/v2 v2.0.2 github.com/appleboy/go-httpretry v0.11.0 github.com/go-authgate/sdk-go v0.6.1 + github.com/gofrs/flock v0.13.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/mattn/go-isatty v0.0.20 diff --git a/go.sum b/go.sum index 97343fc..0b8a253 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/go-authgate/sdk-go v0.6.1 h1:oQREINU63YckTRdJ+0VBmN6ewFSMXa0D862w8624 github.com/go-authgate/sdk-go v0.6.1/go.mod h1:55PLAPuu8GDK0omOwG6lx4c+9/T6dJwZd8kecUueLEk= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/lock.go b/lock.go new file mode 100644 index 0000000..e9aef8a --- /dev/null +++ b/lock.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "time" + + "github.com/gofrs/flock" +) + +const lockRetryInterval = 100 * time.Millisecond + +// tokenStoreLock wraps a process-level advisory lock acquired via flock. +type tokenStoreLock struct { + fl *flock.Flock +} + +// Close releases the advisory lock. +func (l *tokenStoreLock) Close() error { + return l.fl.Unlock() +} + +// lockTokenStore acquires a cross-process advisory lock scoped to +// (tokenFile, clientID). It serialises the "load → refresh → save" +// critical section so concurrent CLI invocations cannot spend the same +// refresh token twice (which would yield invalid_grant on rotation servers). +// +// The lock sits next to the token file regardless of the active backend — +// keyring-backed runs also need the coordination because the race is in the +// refresh flow, not the storage layer. +func lockTokenStore(ctx context.Context, tokenFile, clientID string) (io.Closer, error) { + if tokenFile == "" { + return nil, errors.New("lock: tokenFile is empty") + } + if clientID == "" { + return nil, errors.New("lock: clientID is empty") + } + + dir := filepath.Dir(tokenFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("create lock directory %q: %w", dir, err) + } + + lockPath := filepath.Join(dir, filepath.Base(tokenFile)+"."+clientID+".lock") + fl := flock.New(lockPath) + locked, err := fl.TryLockContext(ctx, lockRetryInterval) + if err != nil { + return nil, fmt.Errorf("acquire lock %s: %w", lockPath, err) + } + if !locked { + return nil, fmt.Errorf("could not acquire lock %s", lockPath) + } + return &tokenStoreLock{fl: fl}, nil +} diff --git a/main.go b/main.go index 83110f4..f4cb690 100644 --- a/main.go +++ b/main.go @@ -92,7 +92,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int { flow = "cached" } else { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenExpired}) - newStorage, err := refreshAccessToken(ctx, cfg, existing.RefreshToken) + newStorage, err := refreshAccessToken(ctx, cfg, existing) if err != nil { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventRefreshFailed, Err: err}) } else { diff --git a/main_test.go b/main_test.go index 752fc8c..9249f2e 100644 --- a/main_test.go +++ b/main_test.go @@ -25,14 +25,14 @@ func testConfig(t *testing.T) *AppConfig { t.Fatalf("failed to create retry client: %v", err) } serverURL := "http://localhost:8080" + tokenFile := filepath.Join(t.TempDir(), "tokens.json") return &AppConfig{ - ServerURL: serverURL, - ClientID: "test-client", - Scope: "email profile", - RetryClient: rc, - Store: credstore.NewTokenFileStore( - filepath.Join(t.TempDir(), "tokens.json"), - ), + ServerURL: serverURL, + ClientID: "test-client", + Scope: "email profile", + RetryClient: rc, + TokenFile: tokenFile, + Store: credstore.NewTokenFileStore(tokenFile), Endpoints: defaultEndpoints(serverURL), TokenExchangeTimeout: defaultTokenExchangeTimeout, TokenVerificationTimeout: defaultTokenVerificationTimeout, @@ -317,7 +317,11 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { cfg.Endpoints = defaultEndpoints(srv.URL) cfg.ClientID = "test-client-rotation" - storage, err := refreshAccessToken(context.Background(), cfg, tt.oldRefreshToken) + storage, err := refreshAccessToken( + context.Background(), + cfg, + credstore.Token{RefreshToken: tt.oldRefreshToken}, + ) if err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } @@ -332,6 +336,57 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { } } +// TestRefreshAccessToken_PeerAlreadyRefreshed verifies that when a peer process +// has already refreshed and saved the tokens, refreshAccessToken returns the +// stored fresh token without making a network call. This guards against the +// refresh-token-rotation race that motivated the cross-process lock. +func TestRefreshAccessToken_PeerAlreadyRefreshed(t *testing.T) { + var serverCalled atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + serverCalled.Store(true) + http.Error(w, "should not be called", http.StatusInternalServerError) + })) + defer srv.Close() + + cfg := testConfig(t) + cfg.ServerURL = srv.URL + cfg.Endpoints = defaultEndpoints(srv.URL) + cfg.ClientID = "peer-refresh-test" + + peerFresh := credstore.Token{ + AccessToken: "peer-fresh-access", + RefreshToken: "peer-fresh-refresh", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + ClientID: cfg.ClientID, + } + if err := cfg.Store.Save(cfg.ClientID, peerFresh); err != nil { + t.Fatalf("setup save: %v", err) + } + + stale := credstore.Token{ + AccessToken: "stale-access", + RefreshToken: "stale-refresh", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-1 * time.Minute), + ClientID: cfg.ClientID, + } + + got, err := refreshAccessToken(context.Background(), cfg, stale) + if err != nil { + t.Fatalf("refreshAccessToken() error: %v", err) + } + if serverCalled.Load() { + t.Fatalf("network refresh was performed; peer-refresh shortcut was expected") + } + if got.AccessToken != peerFresh.AccessToken { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, peerFresh.AccessToken) + } + if got.RefreshToken != peerFresh.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, peerFresh.RefreshToken) + } +} + // ----------------------------------------------------------------------- // Device code request with retry // -----------------------------------------------------------------------