diff --git a/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsCacheDurationStore.cs b/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsCacheDurationStore.cs new file mode 100644 index 000000000..9f0960f15 --- /dev/null +++ b/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsCacheDurationStore.cs @@ -0,0 +1,53 @@ +// Copyright (c) Duende Software. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +using System.Collections.Concurrent; +using Microsoft.Extensions.Options; + +namespace Duende.AccessTokenManagement.Internal; + +/// +/// Store for cache duration auto-tuning state. +/// +internal sealed class ClientCredentialsCacheDurationStore( + IOptions options, + TimeProvider time) +{ + private readonly ClientCredentialsTokenManagementOptions _options = options.Value; + private readonly ConcurrentDictionary _cacheDurations = new(); + + /// + /// Gets the cache duration for a given cache key, or returns the default value if not found. + /// + public TimeSpan GetExpiration(ClientCredentialsCacheKey cacheKey) + { + var cacheExpiration = _options.UseCacheAutoTuning + ? _cacheDurations.GetValueOrDefault(cacheKey, _options.DefaultCacheLifetime) + : _options.DefaultCacheLifetime; + return cacheExpiration; + } + + /// + /// Sets the cache duration for a given cache key. + /// + public TimeSpan SetExpiration(ClientCredentialsCacheKey cacheKey, DateTimeOffset expiration) + { + if (!_options.UseCacheAutoTuning + || expiration == DateTimeOffset.MaxValue) + { + return _options.DefaultCacheLifetime; + } + + // Calculate how long this access token should be valid in the cache. + // Note, the expiration time was just calculated by adding time.GetUTcNow() to the token lifetime. + // so for now it's safe to subtract this time from the expiration time. + + var calculated = expiration + - time.GetUtcNow() + - TimeSpan.FromSeconds(_options.CacheLifetimeBuffer); + + _cacheDurations[cacheKey] = calculated; + + return calculated; + } +} diff --git a/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsTokenManager.cs b/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsTokenManager.cs index 9c732e022..7a0a877ac 100644 --- a/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsTokenManager.cs +++ b/access-token-management/src/AccessTokenManagement/Internal/ClientCredentialsTokenManager.cs @@ -1,7 +1,6 @@ // Copyright (c) Duende Software. All rights reserved. // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. -using System.Collections.Concurrent; using Duende.AccessTokenManagement.OTel; using Microsoft.Extensions.Caching.Hybrid; using Microsoft.Extensions.DependencyInjection; @@ -13,10 +12,12 @@ namespace Duende.AccessTokenManagement.Internal; internal class ClientCredentialsTokenManager( AccessTokenManagementMetrics metrics, IOptions options, - [FromKeyedServices(ServiceProviderKeys.ClientCredentialsTokenCache)] HybridCache cache, + [FromKeyedServices(ServiceProviderKeys.ClientCredentialsTokenCache)] + HybridCache cache, TimeProvider time, IClientCredentialsTokenEndpoint client, IClientCredentialsCacheKeyGenerator cacheKeyGenerator, + ClientCredentialsCacheDurationStore cacheDurationAutoTuningStore, ILogger logger ) : IClientCredentialsTokenManager { @@ -25,11 +26,6 @@ ILogger logger // inside the factory. private const string ThrownInsideFactoryExceptionKey = "Duende.AccessTokenManagement.ThrownInside"; - // We're assuming that the cache duration for access tokens will remain (relatively) stable - // First time we acquire an access token, don't yet know how long it will be valid, so we're assuming - // a specific period. However, after that, we'll use the actual expiration time to set the cache duration. - private readonly ConcurrentDictionary _cacheDurationAutoTuning = new(); - private readonly ClientCredentialsTokenManagementOptions _options = options.Value; public async Task> GetAccessTokenAsync( @@ -41,9 +37,7 @@ public async Task> GetAccessTokenAsync( parameters ??= new TokenRequestParameters(); - var cacheExpiration = _options.UseCacheAutoTuning - ? _cacheDurationAutoTuning.GetValueOrDefault(cacheKey, _options.DefaultCacheLifetime) - : _options.DefaultCacheLifetime; + var cacheExpiration = cacheDurationAutoTuningStore.GetExpiration(cacheKey); // On force renewal, don't read from the cache, so we always get a new token. var disableDistributedCacheRead = parameters.ForceTokenRenewal @@ -75,7 +69,8 @@ public async Task> GetAccessTokenAsync( { // This exception is thrown if there was a failure while retrieving an access token. We // don't want to cache this failure, so we throw an exception to bypass the cache action. - logger.WillNotCacheTokenResultWithError(LogLevel.Debug, clientName, ex.Failure.Error, ex.Failure.ErrorDescription); + logger.WillNotCacheTokenResultWithError(LogLevel.Debug, clientName, ex.Failure.Error, + ex.Failure.ErrorDescription); return ex.Failure; } catch (Exception ex) when (!ex.Data.Contains(ThrownInsideFactoryExceptionKey)) @@ -135,23 +130,14 @@ private async Task RequestToken(ClientCredentialsCacheKe // See if we need to record how long this access token is valid, to be used the next time // this access token is used. - if (_options.UseCacheAutoTuning - && token.Expiration != DateTimeOffset.MaxValue) - { - // Calculate how long this access token should be valid in the cache. - // Note, the expiration time was just calculated by adding time.GetUTcNow() to the token lifetime. - // so for now it's safe to subtract this time from the expiration time. - _cacheDurationAutoTuning[cacheKey] = token.Expiration - - time.GetUtcNow() - - TimeSpan.FromSeconds(_options.CacheLifetimeBuffer); - - logger.CachingAccessToken(LogLevel.Debug, clientName, token.Expiration); - } + var cacheDuration = cacheDurationAutoTuningStore.SetExpiration(cacheKey, token.Expiration); + logger.CachingAccessToken(LogLevel.Debug, clientName, cacheDuration); return token; } - public async Task DeleteAccessTokenAsync(ClientCredentialsClientName clientName, TokenRequestParameters? parameters = null, + public async Task DeleteAccessTokenAsync(ClientCredentialsClientName clientName, + TokenRequestParameters? parameters = null, CT ct = default) { var cacheKey = cacheKeyGenerator.GenerateKey(clientName, parameters); diff --git a/access-token-management/src/AccessTokenManagement/Internal/Generated/Microsoft.Gen.Logging/Microsoft.Gen.Logging.LoggingGenerator/Logging.cs b/access-token-management/src/AccessTokenManagement/Internal/Generated/Microsoft.Gen.Logging/Microsoft.Gen.Logging.LoggingGenerator/Logging.cs index 092ce02e2..ddcae7822 100644 --- a/access-token-management/src/AccessTokenManagement/Internal/Generated/Microsoft.Gen.Logging/Microsoft.Gen.Logging.LoggingGenerator/Logging.cs +++ b/access-token-management/src/AccessTokenManagement/Internal/Generated/Microsoft.Gen.Logging/Microsoft.Gen.Logging.LoggingGenerator/Logging.cs @@ -973,7 +973,7 @@ public static void AuthorizationServerSuppliedNewNonce(this global::Microsoft.E /// Logs "Caching access token for client: {ClientName}. Expiration: {Expiration}". /// [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Gen.Logging", "9.9.0.0")] - public static void CachingAccessToken(this global::Microsoft.Extensions.Logging.ILogger logger, global::Microsoft.Extensions.Logging.LogLevel logLevel, global::Duende.AccessTokenManagement.ClientCredentialsClientName clientName, global::System.DateTimeOffset expiration) + public static void CachingAccessToken(this global::Microsoft.Extensions.Logging.ILogger logger, global::Microsoft.Extensions.Logging.LogLevel logLevel, global::Duende.AccessTokenManagement.ClientCredentialsClientName clientName, global::System.TimeSpan cacheDuration) { if (!logger.IsEnabled(logLevel)) { @@ -985,7 +985,7 @@ public static void CachingAccessToken(this global::Microsoft.Extensions.Logging _ = state.ReserveTagSpace(3); state.TagArray[2] = new("{OriginalFormat}", "Caching access token for client: {ClientName}. Expiration: {Expiration}"); state.TagArray[1] = new("ClientName", clientName.ToString()); - state.TagArray[0] = new("Expiration", expiration); + state.TagArray[0] = new("CacheDuration", cacheDuration); logger.Log( logLevel, diff --git a/access-token-management/src/AccessTokenManagement/ServiceCollectionExtensions.cs b/access-token-management/src/AccessTokenManagement/ServiceCollectionExtensions.cs index 95a298522..86cac2c71 100644 --- a/access-token-management/src/AccessTokenManagement/ServiceCollectionExtensions.cs +++ b/access-token-management/src/AccessTokenManagement/ServiceCollectionExtensions.cs @@ -41,6 +41,7 @@ public static ClientCredentialsTokenManagementBuilder AddClientCredentialsTokenM this IServiceCollection services) { services.TryAddTransient(); + services.TryAddSingleton(); services.AddHybridCache(); // Add a default serializer for ClientCredentialsToken diff --git a/access-token-management/test/AccessTokenManagement.Tests/ClientTokenManagementTests.cs b/access-token-management/test/AccessTokenManagement.Tests/ClientTokenManagementTests.cs index 8c407b88c..c8cd5bce6 100644 --- a/access-token-management/test/AccessTokenManagement.Tests/ClientTokenManagementTests.cs +++ b/access-token-management/test/AccessTokenManagement.Tests/ClientTokenManagementTests.cs @@ -540,4 +540,65 @@ [new KeyValuePair("DPoP-Nonce", "some_nonce")], DPoPJsonWebKey = The.JsonWebKey }); } + + [Fact] + public async Task Cache_auto_tuning_should_persist_across_transient_manager_instances() + { + var tokenExpiry = (int)TimeSpan.FromDays(7).TotalSeconds; + + var fakeCache = new FakeHybridCache(); + _services.AddSingleton(fakeCache); + + _services.AddClientCredentialsTokenManagement(options => + { + options.UseCacheAutoTuning = true; + options.DefaultCacheLifetime = TimeSpan.FromSeconds(30); + options.LocalCacheExpiration = TimeSpan.FromMinutes(10); + options.CacheLifetimeBuffer = 60; + }) + .AddClient("test", client => Some.ClientCredentialsClient(client)); + + _mockHttp.Expect("/connect/token") + .Respond(_ => Some.TokenHttpResponse(Some.Token() with + { + expires_in = tokenExpiry + })); + + _services.AddHttpClient(ClientCredentialsTokenManagementDefaults.BackChannelHttpClientName) + .ConfigurePrimaryHttpMessageHandler(() => _mockHttp); + + var services = _services.BuildServiceProvider(); + + // First request with first manager instance + var firstManager = services.GetRequiredService(); + var firstToken = await firstManager.GetAccessTokenAsync(ClientCredentialsClientName.Parse("test")).GetToken(); + _mockHttp.VerifyNoOutstandingExpectation(); + + firstToken.Expiration.ShouldBe(The.CurrentDateTime.Add(TimeSpan.FromSeconds(tokenExpiry))); + + // Get the cache expiration used for the first request + var firstRequestExpiration = fakeCache.LastOptions?.Expiration; + // The first request doesn't know the token lifetime yet, so it should use DefaultCacheLifetime (30 seconds) + firstRequestExpiration.ShouldBe(TimeSpan.FromSeconds(30)); + + _mockHttp.Expect("/connect/token") + .Respond(_ => Some.TokenHttpResponse(Some.Token() with + { + expires_in = tokenExpiry + })); + + var secondManager = services.GetRequiredService(); + var secondToken = await secondManager.GetAccessTokenAsync(ClientCredentialsClientName.Parse("test")).GetToken(); + _mockHttp.VerifyNoOutstandingExpectation(); + + secondToken.Expiration.ShouldBe(The.CurrentDateTime.Add(TimeSpan.FromSeconds(tokenExpiry))); + + // Get the cache expiration used for the second request + var secondRequestExpiration = fakeCache.LastOptions?.Expiration; + + // Expect the lifetime to be auto-tuned based on the first token's lifetime minus the buffer + var expectedExpiration = TimeSpan.FromSeconds(tokenExpiry) - TimeSpan.FromSeconds(60); + secondRequestExpiration.ShouldBe(expectedExpiration, + "Second request should use the auto-tuned cache duration learned from the first request"); + } } diff --git a/access-token-management/test/AccessTokenManagement.Tests/Framework/FakeHybridCache.cs b/access-token-management/test/AccessTokenManagement.Tests/Framework/FakeHybridCache.cs index 77501c062..dd2ede84a 100644 --- a/access-token-management/test/AccessTokenManagement.Tests/Framework/FakeHybridCache.cs +++ b/access-token-management/test/AccessTokenManagement.Tests/Framework/FakeHybridCache.cs @@ -13,20 +13,27 @@ public class FakeHybridCache : HybridCache public Action OnGetOrCreate = () => { }; - public override async ValueTask GetOrCreateAsync(string key, TState state, Func> factory, HybridCacheEntryOptions? options = null, + public HybridCacheEntryOptions? LastOptions { get; private set; } + + public override async ValueTask GetOrCreateAsync(string key, TState state, + Func> factory, HybridCacheEntryOptions? options = null, IEnumerable? tags = null, CancellationToken cancellationToken = new()) { CacheKey = key; + LastOptions = options; Interlocked.Increment(ref GetOrCreateCount); OnGetOrCreate(); return await factory(state, cancellationToken); } - public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptions? options = null, IEnumerable? tags = null, + public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptions? options = null, + IEnumerable? tags = null, CancellationToken cancellationToken = new()) => throw new NotImplementedException(); - public override ValueTask RemoveAsync(string key, CancellationToken cancellationToken = new()) => throw new NotImplementedException(); + public override ValueTask RemoveAsync(string key, CancellationToken cancellationToken = new()) => + throw new NotImplementedException(); - public override ValueTask RemoveByTagAsync(string tag, CancellationToken cancellationToken = new()) => throw new NotImplementedException(); + public override ValueTask RemoveByTagAsync(string tag, CancellationToken cancellationToken = new()) => + throw new NotImplementedException(); }