diff --git a/utils/src/main/java/software/amazon/awssdk/utils/cache/CachedSupplier.java b/utils/src/main/java/software/amazon/awssdk/utils/cache/CachedSupplier.java index e8ecc4d741d..38cd00033e2 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/cache/CachedSupplier.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/cache/CachedSupplier.java @@ -23,9 +23,9 @@ import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Predicate; import java.util.function.Supplier; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.annotations.SdkTestInternalApi; @@ -54,6 +54,16 @@ public class CachedSupplier implements Supplier, SdkAutoCloseable { */ private static final Duration BLOCKING_REFRESH_MAX_WAIT = Duration.ofSeconds(5); + /** + * Minimum backoff duration when a refresh fails (inclusive). + */ + private static final Duration STATIC_STABILITY_BACKOFF_MIN = Duration.ofMinutes(5); + + /** + * Maximum backoff duration when a refresh fails (inclusive). + */ + private static final Duration STATIC_STABILITY_BACKOFF_MAX = Duration.ofMinutes(10); + /** * Used as a primitive form of rate limiting for the speed of our refreshes. This will make sure that the backing supplier has @@ -83,11 +93,6 @@ public class CachedSupplier implements Supplier, SdkAutoCloseable { */ private final Clock clock; - /** - * The number of consecutive failures encountered when updating a stale value. - */ - private final AtomicInteger consecutiveStaleRetrievalFailures = new AtomicInteger(0); - /** * The name to include with each log message, to differentiate caches. 
*/ @@ -108,6 +113,12 @@ public class CachedSupplier implements Supplier, SdkAutoCloseable { */ private final Random jitterRandom = new Random(); + /** + * Predicate that determines whether an exception represents a non-recoverable refresh failure + * that should bypass static stability (i.e., be re-thrown immediately without extending expiration). + */ + private final Predicate cacheInvalidatingPredicate; + private CachedSupplier(Builder builder) { Validate.notNull(builder.supplier, "builder.supplier"); Validate.notNull(builder.jitterEnabled, "builder.jitterEnabled"); @@ -117,6 +128,7 @@ private CachedSupplier(Builder builder) { this.staleValueBehavior = Validate.notNull(builder.staleValueBehavior, "builder.staleValueBehavior"); this.clock = Validate.notNull(builder.clock, "builder.clock"); this.cachedValueName = Validate.notNull(builder.cachedValueName, "builder.cachedValueName"); + this.cacheInvalidatingPredicate = builder.cacheInvalidatingPredicate; } /** @@ -229,8 +241,6 @@ private void refreshCache() { * Perform necessary transformations of the successfully-fetched value based on the stale value behavior of this supplier. 
*/ private RefreshResult handleFetchedSuccess(RefreshResult fetch) { - consecutiveStaleRetrievalFailures.set(0); - Instant now = clock.instant(); if (now.isBefore(fetch.staleTime())) { @@ -269,25 +279,57 @@ private RefreshResult handleFetchFailure(RuntimeException e) { Instant now = clock.instant(); if (!now.isBefore(currentCachedValue.staleTime())) { - int numFailures = consecutiveStaleRetrievalFailures.incrementAndGet(); - switch (staleValueBehavior) { case STRICT: throw e; case ALLOW: - Instant newStaleTime = jitterTime(now, Duration.ofMillis(1), maxStaleFailureJitter(numFailures)); - log.warn(() -> "(" + cachedValueName + ") Cached value expiration has been extended to " + - newStaleTime + " because calling the downstream service failed (consecutive failures: " + - numFailures + ").", e); + // Cache-invalidating errors bypass static stability + if (cacheInvalidatingPredicate != null && cacheInvalidatingPredicate.test(e)) { + throw e; + } + + // Uniform random backoff: 5-10 minutes + long backoffSeconds = STATIC_STABILITY_BACKOFF_MIN.getSeconds() + + jitterRandom.nextInt( + (int) (STATIC_STABILITY_BACKOFF_MAX.getSeconds() + - STATIC_STABILITY_BACKOFF_MIN.getSeconds() + 1)); + Instant extendedStaleTime = now.plusSeconds(backoffSeconds); + + log.warn(() -> "(" + cachedValueName + ") Credential refresh failed: " + e.getMessage() + + ". Extending cached credential expiration. A refresh of these credentials" + + " will be attempted again after " + backoffSeconds + " seconds.", e); return currentCachedValue.toBuilder() - .staleTime(newStaleTime) + .staleTime(extendedStaleTime) + .prefetchTime(extendedStaleTime) .build(); default: throw new IllegalStateException("Unknown stale-value-behavior: " + staleValueBehavior); } } + // Not yet stale — we're in the prefetch window. Handle failure based on mode. 
+ if (staleValueBehavior == StaleValueBehavior.ALLOW) { + if (cacheInvalidatingPredicate != null && cacheInvalidatingPredicate.test(e)) { + throw e; + } + // During prefetch window failure: extend prefetchTime to suppress further attempts + long backoffSeconds = STATIC_STABILITY_BACKOFF_MIN.getSeconds() + + jitterRandom.nextInt( + (int) (STATIC_STABILITY_BACKOFF_MAX.getSeconds() + - STATIC_STABILITY_BACKOFF_MIN.getSeconds() + 1)); + Instant extendedPrefetchTime = now.plusSeconds(backoffSeconds); + + log.warn(() -> "(" + cachedValueName + ") Credential refresh failed: " + e.getMessage() + + ". Extending cached credential expiration. A refresh of these credentials" + + " will be attempted again after " + backoffSeconds + " seconds.", e); + + return currentCachedValue.toBuilder() + .staleTime(extendedPrefetchTime) + .prefetchTime(extendedPrefetchTime) + .build(); + } + return currentCachedValue; } @@ -333,6 +375,12 @@ private Duration maxPrefetchJitter(RefreshResult result) { return timeBetweenPrefetchAndStale; } + private Instant jitterTime(Instant time, Duration jitterStart, Duration jitterEnd) { + long jitterRange = jitterEnd.minus(jitterStart).toMillis(); + long jitterAmount = Math.abs(jitterRandom.nextLong() % jitterRange); + return time.plus(jitterStart).plusMillis(jitterAmount); + } + private Duration maxStaleFailureJitter(int numFailures) { // prevent cycling back through low values if (numFailures > 63) { @@ -350,12 +398,6 @@ protected Duration maxStaleFailureJitterTest(int numFailures) { return maxStaleFailureJitter(numFailures); } - private Instant jitterTime(Instant time, Duration jitterStart, Duration jitterEnd) { - long jitterRange = jitterEnd.minus(jitterStart).toMillis(); - long jitterAmount = Math.abs(jitterRandom.nextLong() % jitterRange); - return time.plus(jitterStart).plusMillis(jitterAmount); - } - /** * Free any resources consumed by the prefetch strategy this supplier is using. 
*/ @@ -374,6 +416,7 @@ public static final class Builder { private StaleValueBehavior staleValueBehavior = StaleValueBehavior.STRICT; private Clock clock = Clock.systemUTC(); private String cachedValueName = "unknown"; + private Predicate cacheInvalidatingPredicate; private Builder(Supplier> supplier) { this.supplier = supplier; @@ -413,6 +456,23 @@ public Builder cachedValueName(String cachedValueName) { return this; } + /** + * Configure a predicate that determines whether an exception represents a non-recoverable refresh failure + * that should bypass static stability. When the predicate returns {@code true} for a given exception, + * the exception will be re-thrown immediately without extending the cached value's expiration. + * + *

<p>This is used for errors where the credential source has definitively indicated that the current + * authentication state is invalid and requires user intervention (e.g., expired SSO tokens, + * changed user credentials).</p>

+ * + *

<p>By default, no exceptions are considered cache-invalidating (all failures trigger static stability + * backoff when {@link StaleValueBehavior#ALLOW} is configured).</p>

+ */ + public Builder cacheInvalidatingPredicate(Predicate cacheInvalidatingPredicate) { + this.cacheInvalidatingPredicate = cacheInvalidatingPredicate; + return this; + } + /** * Configure the clock used for this cached supplier. Configurable for testing. */ @@ -488,8 +548,14 @@ public enum StaleValueBehavior { STRICT, /** - * Allow stale values to be returned from the cache. Value retrieval will never fail, as long as the cache has - * succeeded when calling the underlying supplier at least once. + * Allow stale values to be returned from the cache with static stability semantics. On refresh failure, + * extends the stale time by a uniformly random backoff between 5 and 10 minutes (300-600 seconds). + * + *

<p>If a {@link Builder#cacheInvalidatingPredicate(Predicate)} is configured and returns {@code true} + * for the exception, that exception is re-thrown immediately without extending the stale time.</p>

+ * + *

<p>Value retrieval will never fail as long as the cache has succeeded at least once, + * unless the error is cache-invalidating.</p>

*/ ALLOW } diff --git a/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java b/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java index 159e2d69b6e..5d0c9ec4381 100644 --- a/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java +++ b/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java @@ -364,25 +364,190 @@ public void throwIsHiddenIfValueIsStaleInAllowMode() throws InterruptedException } @Test - public void maxStaleFailureJitter_shouldNotReturnNegativeOrCycleLowValues() { - CachedSupplier supplier = CachedSupplier.builder(() -> RefreshResult.builder("v") - .staleTime(Instant.MAX) - .build()) - .build(); - - for (int i = 1; i <= 70; i++) { - Duration jitter = supplier.maxStaleFailureJitterTest(i); - assertThat(jitter) - .as("numFailures=%d: jitter must be positive", i) - .isPositive(); - - if (i > 64) { - assertThat(jitter) - .isEqualTo(Duration.ofSeconds(10)); + public void allowMode_returnsCachedValueOnNonCacheInvalidatingFailure() throws InterruptedException { + AdjustableClock clock = new AdjustableClock(); + MutableSupplier supplier = new MutableSupplier(); + try (CachedSupplier cachedSupplier = CachedSupplier.builder(supplier) + .staleValueBehavior(ALLOW) + .clock(clock) + .jitterEnabled(false) + .build()) { + Instant now = Instant.now(); + clock.time = now; + + // Initial successful fetch + supplier.set(RefreshResult.builder("cached-creds") + .staleTime(now.plusSeconds(60)) + .prefetchTime(now.plusSeconds(30)) + .build()); + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Advance past stale time + clock.time = now.plusSeconds(61); + supplier.set(new RuntimeException("service unavailable")); + + // Should return cached value instead of throwing + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + } + } + + @Test + public void allowMode_cacheInvalidatingError_isRethrown() throws InterruptedException { + AdjustableClock clock = 
new AdjustableClock(); + MutableSupplier supplier = new MutableSupplier(); + try (CachedSupplier cachedSupplier = CachedSupplier.builder(supplier) + .staleValueBehavior(ALLOW) + .cacheInvalidatingPredicate( + e -> e instanceof CacheInvalidatingRuntimeException) + .clock(clock) + .jitterEnabled(false) + .build()) { + Instant now = Instant.now(); + clock.time = now; + + // Initial successful fetch + supplier.set(RefreshResult.builder("cached-creds") + .staleTime(now.plusSeconds(60)) + .prefetchTime(now.plusSeconds(30)) + .build()); + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Advance past stale time and throw cache-invalidating error + clock.time = now.plusSeconds(61); + CacheInvalidatingRuntimeException invalidatingError = + new CacheInvalidatingRuntimeException("token expired"); + supplier.set(invalidatingError); + + // Should re-throw even though cached value exists + assertThatThrownBy(cachedSupplier::get).isEqualTo(invalidatingError); + } + } + + @Test + public void allowMode_backoffIsInExpectedRange() throws InterruptedException { + AdjustableClock clock = new AdjustableClock(); + MutableSupplier supplier = new MutableSupplier(); + + // Run multiple iterations to verify backoff range + for (int i = 0; i < 50; i++) { + try (CachedSupplier cachedSupplier = CachedSupplier.builder(supplier) + .staleValueBehavior(ALLOW) + .clock(clock) + .jitterEnabled(false) + .build()) { + Instant now = Instant.parse("2024-01-01T00:00:00Z"); + clock.time = now; + + supplier.set(RefreshResult.builder("cached-creds") + .staleTime(now.plusSeconds(60)) + .prefetchTime(now.plusSeconds(30)) + .build()); + cachedSupplier.get(); + + // Advance past stale time and trigger failure + clock.time = now.plusSeconds(61); + supplier.set(new RuntimeException("service unavailable")); + cachedSupplier.get(); + + // Advance well past the extended time to test that the backoff was applied + // The extended stale time should be: now(61) + [300,600]s(backoff) + // So total offset 
from epoch: 61 + [300,600] = [361, 661] seconds from original now + Instant minExpectedStale = now.plusSeconds(61 + 300); + Instant maxExpectedStale = now.plusSeconds(61 + 600); + + // Advance just before the minimum backoff - should still return cached (not stale yet) + clock.time = minExpectedStale.minusSeconds(1); + supplier.set(RefreshResult.builder("new-creds") + .staleTime(Instant.MAX) + .prefetchTime(Instant.MAX) + .build()); + // Value not stale yet so should return cached + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Advance past maximum possible backoff - must be stale now and will refresh + clock.time = maxExpectedStale.plusSeconds(1); + assertThat(cachedSupplier.get()).isEqualTo("new-creds"); } } + } - supplier.close(); + @Test + public void allowMode_prefetchWindowFailure_extendsPrefetchTime() { + AdjustableClock clock = new AdjustableClock(); + MutableSupplier supplier = new MutableSupplier(); + try (CachedSupplier cachedSupplier = CachedSupplier.builder(supplier) + .staleValueBehavior(ALLOW) + .clock(clock) + .jitterEnabled(false) + .build()) { + Instant now = Instant.parse("2024-01-01T00:00:00Z"); + clock.time = now; + + // Initial successful fetch with prefetch in the future, stale much later + supplier.set(RefreshResult.builder("cached-creds") + .staleTime(now.plusSeconds(3600)) + .prefetchTime(now.plusSeconds(60)) + .build()); + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Advance past prefetch time but before stale time + clock.time = now.plusSeconds(61); + supplier.set(new RuntimeException("service unavailable")); + + // Should return cached value (not throw) and extend prefetch time + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Verify that a subsequent call shortly after does NOT attempt another refresh + // (because prefetchTime was extended) + clock.time = now.plusSeconds(62); + supplier.set(RefreshResult.builder("should-not-get-this") + .staleTime(Instant.MAX) + 
.prefetchTime(Instant.MAX) + .build()); + // The prefetchTime was extended far into the future, so this should still return cached + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + } + } + + @Test + public void allowMode_prefetchWindowFailure_cacheInvalidatingError_isRethrown() { + AdjustableClock clock = new AdjustableClock(); + MutableSupplier supplier = new MutableSupplier(); + try (CachedSupplier cachedSupplier = CachedSupplier.builder(supplier) + .staleValueBehavior(ALLOW) + .cacheInvalidatingPredicate( + e -> e instanceof CacheInvalidatingRuntimeException) + .clock(clock) + .jitterEnabled(false) + .build()) { + Instant now = Instant.parse("2024-01-01T00:00:00Z"); + clock.time = now; + + // Initial successful fetch with prefetch in the future, stale much later + supplier.set(RefreshResult.builder("cached-creds") + .staleTime(now.plusSeconds(3600)) + .prefetchTime(now.plusSeconds(60)) + .build()); + assertThat(cachedSupplier.get()).isEqualTo("cached-creds"); + + // Advance past prefetch time but before stale time + clock.time = now.plusSeconds(61); + CacheInvalidatingRuntimeException invalidatingError = + new CacheInvalidatingRuntimeException("token expired"); + supplier.set(invalidatingError); + + // Should re-throw cache-invalidating error even in prefetch window + assertThatThrownBy(cachedSupplier::get).isEqualTo(invalidatingError); + } + } + + /** + * A RuntimeException that represents a cache-invalidating error for testing. + */ + private static class CacheInvalidatingRuntimeException extends RuntimeException { + CacheInvalidatingRuntimeException(String message) { + super(message); + } } @Test