Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.SdkProtectedApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
Expand Down Expand Up @@ -54,6 +54,16 @@ public class CachedSupplier<T> implements Supplier<T>, SdkAutoCloseable {
*/
private static final Duration BLOCKING_REFRESH_MAX_WAIT = Duration.ofSeconds(5);

/**
* Minimum backoff duration when a refresh fails (inclusive).
*/
private static final Duration STATIC_STABILITY_BACKOFF_MIN = Duration.ofMinutes(5);

/**
* Maximum backoff duration when a refresh fails (inclusive).
*/
private static final Duration STATIC_STABILITY_BACKOFF_MAX = Duration.ofMinutes(10);


/**
* Used as a primitive form of rate limiting for the speed of our refreshes. This will make sure that the backing supplier has
Expand Down Expand Up @@ -83,11 +93,6 @@ public class CachedSupplier<T> implements Supplier<T>, SdkAutoCloseable {
*/
private final Clock clock;

/**
* The number of consecutive failures encountered when updating a stale value.
*/
private final AtomicInteger consecutiveStaleRetrievalFailures = new AtomicInteger(0);

/**
* The name to include with each log message, to differentiate caches.
*/
Expand All @@ -108,6 +113,12 @@ public class CachedSupplier<T> implements Supplier<T>, SdkAutoCloseable {
*/
private final Random jitterRandom = new Random();

/**
* Predicate that determines whether an exception represents a non-recoverable refresh failure
* that should bypass static stability (i.e., be re-thrown immediately without extending expiration).
*/
private final Predicate<RuntimeException> cacheInvalidatingPredicate;

private CachedSupplier(Builder<T> builder) {
Validate.notNull(builder.supplier, "builder.supplier");
Validate.notNull(builder.jitterEnabled, "builder.jitterEnabled");
Expand All @@ -117,6 +128,7 @@ private CachedSupplier(Builder<T> builder) {
this.staleValueBehavior = Validate.notNull(builder.staleValueBehavior, "builder.staleValueBehavior");
this.clock = Validate.notNull(builder.clock, "builder.clock");
this.cachedValueName = Validate.notNull(builder.cachedValueName, "builder.cachedValueName");
this.cacheInvalidatingPredicate = builder.cacheInvalidatingPredicate;
}

/**
Expand Down Expand Up @@ -229,8 +241,6 @@ private void refreshCache() {
* Perform necessary transformations of the successfully-fetched value based on the stale value behavior of this supplier.
*/
private RefreshResult<T> handleFetchedSuccess(RefreshResult<T> fetch) {
consecutiveStaleRetrievalFailures.set(0);

Instant now = clock.instant();

if (now.isBefore(fetch.staleTime())) {
Expand Down Expand Up @@ -269,25 +279,57 @@ private RefreshResult<T> handleFetchFailure(RuntimeException e) {

Instant now = clock.instant();
if (!now.isBefore(currentCachedValue.staleTime())) {
int numFailures = consecutiveStaleRetrievalFailures.incrementAndGet();

switch (staleValueBehavior) {
case STRICT:
throw e;
case ALLOW:
Instant newStaleTime = jitterTime(now, Duration.ofMillis(1), maxStaleFailureJitter(numFailures));
log.warn(() -> "(" + cachedValueName + ") Cached value expiration has been extended to " +
newStaleTime + " because calling the downstream service failed (consecutive failures: " +
numFailures + ").", e);
// Cache-invalidating errors bypass static stability
if (cacheInvalidatingPredicate != null && cacheInvalidatingPredicate.test(e)) {
throw e;
}

// Uniform random backoff: 5-10 minutes
long backoffSeconds = STATIC_STABILITY_BACKOFF_MIN.getSeconds()
+ jitterRandom.nextInt(
(int) (STATIC_STABILITY_BACKOFF_MAX.getSeconds()
- STATIC_STABILITY_BACKOFF_MIN.getSeconds() + 1));
Instant extendedStaleTime = now.plusSeconds(backoffSeconds);

log.warn(() -> "(" + cachedValueName + ") Credential refresh failed: " + e.getMessage()
+ ". Extending cached credential expiration. A refresh of these credentials"
+ " will be attempted again after " + backoffSeconds + " seconds.", e);

return currentCachedValue.toBuilder()
.staleTime(newStaleTime)
.staleTime(extendedStaleTime)
.prefetchTime(extendedStaleTime)
.build();
default:
throw new IllegalStateException("Unknown stale-value-behavior: " + staleValueBehavior);
}
}

// Not yet stale — we're in the prefetch window. Handle failure based on mode.
if (staleValueBehavior == StaleValueBehavior.ALLOW) {
if (cacheInvalidatingPredicate != null && cacheInvalidatingPredicate.test(e)) {
throw e;
}
// During prefetch window failure: extend prefetchTime to suppress further attempts
long backoffSeconds = STATIC_STABILITY_BACKOFF_MIN.getSeconds()
+ jitterRandom.nextInt(
(int) (STATIC_STABILITY_BACKOFF_MAX.getSeconds()
- STATIC_STABILITY_BACKOFF_MIN.getSeconds() + 1));
Instant extendedPrefetchTime = now.plusSeconds(backoffSeconds);

log.warn(() -> "(" + cachedValueName + ") Credential refresh failed: " + e.getMessage()
+ ". Extending cached credential expiration. A refresh of these credentials"
+ " will be attempted again after " + backoffSeconds + " seconds.", e);

return currentCachedValue.toBuilder()
.staleTime(extendedPrefetchTime)
.prefetchTime(extendedPrefetchTime)
.build();
}

return currentCachedValue;
}

Expand Down Expand Up @@ -333,6 +375,12 @@ private Duration maxPrefetchJitter(RefreshResult<T> result) {
return timeBetweenPrefetchAndStale;
}

/**
 * Offset {@code time} by {@code jitterStart} plus a random number of milliseconds that is smaller than the
 * width (in whole milliseconds) of the window between {@code jitterStart} and {@code jitterEnd}.
 *
 * @param time the instant to offset
 * @param jitterStart the fixed portion of the offset that is always applied
 * @param jitterEnd the end of the jitter window
 * @return {@code time} shifted by {@code jitterStart} and the random jitter amount
 */
private Instant jitterTime(Instant time, Duration jitterStart, Duration jitterEnd) {
    long jitterRange = jitterEnd.minus(jitterStart).toMillis();
    if (jitterRange == 0) {
        // A window narrower than one millisecond leaves nothing to randomize, and "% 0" below would
        // throw ArithmeticException.
        return time.plus(jitterStart);
    }
    long jitterAmount = Math.abs(jitterRandom.nextLong() % jitterRange);
    return time.plus(jitterStart).plusMillis(jitterAmount);
}

private Duration maxStaleFailureJitter(int numFailures) {
// prevent cycling back through low values
if (numFailures > 63) {
Expand All @@ -350,12 +398,6 @@ protected Duration maxStaleFailureJitterTest(int numFailures) {
return maxStaleFailureJitter(numFailures);
}

/**
 * Offset {@code time} by {@code jitterStart} plus a random number of milliseconds that is smaller than the
 * width (in whole milliseconds) of the window between {@code jitterStart} and {@code jitterEnd}.
 *
 * @param time the instant to offset
 * @param jitterStart the fixed portion of the offset that is always applied
 * @param jitterEnd the end of the jitter window
 * @return {@code time} shifted by {@code jitterStart} and the random jitter amount
 */
private Instant jitterTime(Instant time, Duration jitterStart, Duration jitterEnd) {
    long jitterRange = jitterEnd.minus(jitterStart).toMillis();
    if (jitterRange == 0) {
        // A window narrower than one millisecond leaves nothing to randomize, and "% 0" below would
        // throw ArithmeticException.
        return time.plus(jitterStart);
    }
    long jitterAmount = Math.abs(jitterRandom.nextLong() % jitterRange);
    return time.plus(jitterStart).plusMillis(jitterAmount);
}

/**
* Free any resources consumed by the prefetch strategy this supplier is using.
*/
Expand All @@ -374,6 +416,7 @@ public static final class Builder<T> {
private StaleValueBehavior staleValueBehavior = StaleValueBehavior.STRICT;
private Clock clock = Clock.systemUTC();
private String cachedValueName = "unknown";
private Predicate<RuntimeException> cacheInvalidatingPredicate;

private Builder(Supplier<RefreshResult<T>> supplier) {
this.supplier = supplier;
Expand Down Expand Up @@ -413,6 +456,23 @@ public Builder<T> cachedValueName(String cachedValueName) {
return this;
}

/**
* Configure a predicate that determines whether an exception represents a non-recoverable refresh failure
* that should bypass static stability. When the predicate returns {@code true} for a given exception,
* the exception will be re-thrown immediately without extending the cached value's expiration.
*
* <p>This is used for errors where the credential source has definitively indicated that the current
* authentication state is invalid and requires user intervention (e.g., expired SSO tokens,
* changed user credentials).</p>
*
* <p>By default, no exceptions are considered cache-invalidating (all failures trigger static stability
* backoff when {@link StaleValueBehavior#ALLOW} is configured).</p>
*
* @param cacheInvalidatingPredicate the predicate tested against each refresh failure; may be {@code null},
* in which case no failure is treated as cache-invalidating
* @return this builder, so calls can be chained
*/
public Builder<T> cacheInvalidatingPredicate(Predicate<RuntimeException> cacheInvalidatingPredicate) {
// Stored as-is; a null predicate is tolerated because it is null-checked before being tested.
this.cacheInvalidatingPredicate = cacheInvalidatingPredicate;
return this;
}

/**
* Configure the clock used for this cached supplier. Configurable for testing.
*/
Expand Down Expand Up @@ -488,8 +548,14 @@ public enum StaleValueBehavior {
STRICT,

/**
* Allow stale values to be returned from the cache. Value retrieval will never fail, as long as the cache has
* succeeded when calling the underlying supplier at least once.
* Allow stale values to be returned from the cache with static stability semantics. On refresh failure,
* extends the stale time by a uniformly random backoff between 5 and 10 minutes (300-600 seconds).
*
* <p>If a {@link Builder#cacheInvalidatingPredicate(Predicate)} is configured and returns {@code true}
* for the exception, it is re-thrown immediately without extending the stale time.</p>
*
* <p>Value retrieval will never fail as long as the cache has succeeded at least once,
* unless the error is cache-invalidating.</p>
*/
ALLOW
}
Expand Down
Loading
Loading