Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Duende Software. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using System.Collections.Concurrent;
using Microsoft.Extensions.Options;

namespace Duende.AccessTokenManagement.Internal;

/// <summary>
/// Store for cache duration auto-tuning state.
/// </summary>
internal sealed class ClientCredentialsCacheDurationStore(
IOptions<ClientCredentialsTokenManagementOptions> options,
TimeProvider time)
{
private readonly ClientCredentialsTokenManagementOptions _options = options.Value;
private readonly ConcurrentDictionary<ClientCredentialsCacheKey, TimeSpan> _cacheDurations = new();

/// <summary>
/// Gets the cache duration for a given cache key, or returns the default value if not found.
/// </summary>
public TimeSpan GetExpiration(ClientCredentialsCacheKey cacheKey)
{
var cacheExpiration = _options.UseCacheAutoTuning
? _cacheDurations.GetValueOrDefault(cacheKey, _options.DefaultCacheLifetime)
: _options.DefaultCacheLifetime;
return cacheExpiration;
}

/// <summary>
/// Sets the cache duration for a given cache key.
/// </summary>
public TimeSpan SetExpiration(ClientCredentialsCacheKey cacheKey, DateTimeOffset expiration)
{
if (!_options.UseCacheAutoTuning
|| expiration == DateTimeOffset.MaxValue)
{
return _options.DefaultCacheLifetime;
}

// Calculate how long this access token should be valid in the cache.
// Note, the expiration time was just calculated by adding time.GetUTcNow() to the token lifetime.
// so for now it's safe to subtract this time from the expiration time.

var calculated = expiration
- time.GetUtcNow()
- TimeSpan.FromSeconds(_options.CacheLifetimeBuffer);

_cacheDurations[cacheKey] = calculated;

return calculated;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Duende Software. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using System.Collections.Concurrent;
using Duende.AccessTokenManagement.OTel;
using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -13,10 +12,12 @@ namespace Duende.AccessTokenManagement.Internal;
internal class ClientCredentialsTokenManager(
AccessTokenManagementMetrics metrics,
IOptions<ClientCredentialsTokenManagementOptions> options,
[FromKeyedServices(ServiceProviderKeys.ClientCredentialsTokenCache)] HybridCache cache,
[FromKeyedServices(ServiceProviderKeys.ClientCredentialsTokenCache)]
HybridCache cache,
TimeProvider time,
IClientCredentialsTokenEndpoint client,
IClientCredentialsCacheKeyGenerator cacheKeyGenerator,
ClientCredentialsCacheDurationStore cacheDurationAutoTuningStore,
ILogger<ClientCredentialsTokenManager> logger
) : IClientCredentialsTokenManager
{
Expand All @@ -25,11 +26,6 @@ ILogger<ClientCredentialsTokenManager> logger
// inside the factory.
private const string ThrownInsideFactoryExceptionKey = "Duende.AccessTokenManagement.ThrownInside";

// We're assuming that the cache duration for access tokens will remain (relatively) stable
// First time we acquire an access token, don't yet know how long it will be valid, so we're assuming
// a specific period. However, after that, we'll use the actual expiration time to set the cache duration.
private readonly ConcurrentDictionary<ClientCredentialsCacheKey, TimeSpan> _cacheDurationAutoTuning = new();

private readonly ClientCredentialsTokenManagementOptions _options = options.Value;

public async Task<TokenResult<ClientCredentialsToken>> GetAccessTokenAsync(
Expand All @@ -41,9 +37,7 @@ public async Task<TokenResult<ClientCredentialsToken>> GetAccessTokenAsync(

parameters ??= new TokenRequestParameters();

var cacheExpiration = _options.UseCacheAutoTuning
? _cacheDurationAutoTuning.GetValueOrDefault(cacheKey, _options.DefaultCacheLifetime)
: _options.DefaultCacheLifetime;
var cacheExpiration = cacheDurationAutoTuningStore.GetExpiration(cacheKey);

// On force renewal, don't read from the cache, so we always get a new token.
var disableDistributedCacheRead = parameters.ForceTokenRenewal
Expand Down Expand Up @@ -75,7 +69,8 @@ public async Task<TokenResult<ClientCredentialsToken>> GetAccessTokenAsync(
{
// This exception is thrown if there was a failure while retrieving an access token. We
// don't want to cache this failure, so we throw an exception to bypass the cache action.
logger.WillNotCacheTokenResultWithError(LogLevel.Debug, clientName, ex.Failure.Error, ex.Failure.ErrorDescription);
logger.WillNotCacheTokenResultWithError(LogLevel.Debug, clientName, ex.Failure.Error,
ex.Failure.ErrorDescription);
return ex.Failure;
}
catch (Exception ex) when (!ex.Data.Contains(ThrownInsideFactoryExceptionKey))
Expand Down Expand Up @@ -135,23 +130,14 @@ private async Task<ClientCredentialsToken> RequestToken(ClientCredentialsCacheKe

// See if we need to record how long this access token is valid, to be used the next time
// this access token is used.
if (_options.UseCacheAutoTuning
&& token.Expiration != DateTimeOffset.MaxValue)
{
// Calculate how long this access token should be valid in the cache.
// Note, the expiration time was just calculated by adding time.GetUTcNow() to the token lifetime.
// so for now it's safe to subtract this time from the expiration time.
_cacheDurationAutoTuning[cacheKey] = token.Expiration
- time.GetUtcNow()
- TimeSpan.FromSeconds(_options.CacheLifetimeBuffer);

logger.CachingAccessToken(LogLevel.Debug, clientName, token.Expiration);
}
var cacheDuration = cacheDurationAutoTuningStore.SetExpiration(cacheKey, token.Expiration);
logger.CachingAccessToken(LogLevel.Debug, clientName, cacheDuration);

return token;
}

public async Task DeleteAccessTokenAsync(ClientCredentialsClientName clientName, TokenRequestParameters? parameters = null,
public async Task DeleteAccessTokenAsync(ClientCredentialsClientName clientName,
TokenRequestParameters? parameters = null,
CT ct = default)
{
var cacheKey = cacheKeyGenerator.GenerateKey(clientName, parameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ public static void AuthorizationServerSuppliedNewNonce(this global::Microsoft.E
/// Logs "Caching access token for client: {ClientName}. Expiration: {Expiration}".
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.Gen.Logging", "9.9.0.0")]
public static void CachingAccessToken(this global::Microsoft.Extensions.Logging.ILogger logger, global::Microsoft.Extensions.Logging.LogLevel logLevel, global::Duende.AccessTokenManagement.ClientCredentialsClientName clientName, global::System.DateTimeOffset expiration)
public static void CachingAccessToken(this global::Microsoft.Extensions.Logging.ILogger logger, global::Microsoft.Extensions.Logging.LogLevel logLevel, global::Duende.AccessTokenManagement.ClientCredentialsClientName clientName, global::System.TimeSpan cacheDuration)
{
if (!logger.IsEnabled(logLevel))
{
Expand All @@ -985,7 +985,7 @@ public static void CachingAccessToken(this global::Microsoft.Extensions.Logging
_ = state.ReserveTagSpace(3);
state.TagArray[2] = new("{OriginalFormat}", "Caching access token for client: {ClientName}. Expiration: {Expiration}");
state.TagArray[1] = new("ClientName", clientName.ToString());
state.TagArray[0] = new("Expiration", expiration);
state.TagArray[0] = new("CacheDuration", cacheDuration);

logger.Log(
logLevel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public static ClientCredentialsTokenManagementBuilder AddClientCredentialsTokenM
this IServiceCollection services)
{
services.TryAddTransient<IClientCredentialsTokenManager, ClientCredentialsTokenManager>();
services.TryAddSingleton<ClientCredentialsCacheDurationStore>();
services.AddHybridCache();

// Add a default serializer for ClientCredentialsToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,4 +540,65 @@ [new KeyValuePair<string, string>("DPoP-Nonce", "some_nonce")],
DPoPJsonWebKey = The.JsonWebKey
});
}

[Fact]
public async Task Cache_auto_tuning_should_persist_across_transient_manager_instances()
{
var tokenExpiry = (int)TimeSpan.FromDays(7).TotalSeconds;

var fakeCache = new FakeHybridCache();
_services.AddSingleton<HybridCache>(fakeCache);

_services.AddClientCredentialsTokenManagement(options =>
{
options.UseCacheAutoTuning = true;
options.DefaultCacheLifetime = TimeSpan.FromSeconds(30);
options.LocalCacheExpiration = TimeSpan.FromMinutes(10);
options.CacheLifetimeBuffer = 60;
})
.AddClient("test", client => Some.ClientCredentialsClient(client));

_mockHttp.Expect("/connect/token")
.Respond(_ => Some.TokenHttpResponse(Some.Token() with
{
expires_in = tokenExpiry
}));

_services.AddHttpClient(ClientCredentialsTokenManagementDefaults.BackChannelHttpClientName)
.ConfigurePrimaryHttpMessageHandler(() => _mockHttp);

var services = _services.BuildServiceProvider();

// First request with first manager instance
var firstManager = services.GetRequiredService<IClientCredentialsTokenManager>();
var firstToken = await firstManager.GetAccessTokenAsync(ClientCredentialsClientName.Parse("test")).GetToken();
_mockHttp.VerifyNoOutstandingExpectation();

firstToken.Expiration.ShouldBe(The.CurrentDateTime.Add(TimeSpan.FromSeconds(tokenExpiry)));

// Get the cache expiration used for the first request
var firstRequestExpiration = fakeCache.LastOptions?.Expiration;
// The first request doesn't know the token lifetime yet, so it should use DefaultCacheLifetime (30 seconds)
firstRequestExpiration.ShouldBe(TimeSpan.FromSeconds(30));

_mockHttp.Expect("/connect/token")
.Respond(_ => Some.TokenHttpResponse(Some.Token() with
{
expires_in = tokenExpiry
}));

var secondManager = services.GetRequiredService<IClientCredentialsTokenManager>();
var secondToken = await secondManager.GetAccessTokenAsync(ClientCredentialsClientName.Parse("test")).GetToken();
_mockHttp.VerifyNoOutstandingExpectation();

secondToken.Expiration.ShouldBe(The.CurrentDateTime.Add(TimeSpan.FromSeconds(tokenExpiry)));

// Get the cache expiration used for the second request
var secondRequestExpiration = fakeCache.LastOptions?.Expiration;

// Expect the lifetime to be auto-tuned based on the first token's lifetime minus the buffer
var expectedExpiration = TimeSpan.FromSeconds(tokenExpiry) - TimeSpan.FromSeconds(60);
secondRequestExpiration.ShouldBe(expectedExpiration,
"Second request should use the auto-tuned cache duration learned from the first request");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,27 @@ public class FakeHybridCache : HybridCache

public Action OnGetOrCreate = () => { };

public override async ValueTask<T> GetOrCreateAsync<TState, T>(string key, TState state, Func<TState, CancellationToken, ValueTask<T>> factory, HybridCacheEntryOptions? options = null,
public HybridCacheEntryOptions? LastOptions { get; private set; }

public override async ValueTask<T> GetOrCreateAsync<TState, T>(string key, TState state,
Func<TState, CancellationToken, ValueTask<T>> factory, HybridCacheEntryOptions? options = null,
IEnumerable<string>? tags = null, CancellationToken cancellationToken = new())
{
CacheKey = key;
LastOptions = options;
Interlocked.Increment(ref GetOrCreateCount);
OnGetOrCreate();
return await factory(state, cancellationToken);
}

public override ValueTask SetAsync<T>(string key, T value, HybridCacheEntryOptions? options = null, IEnumerable<string>? tags = null,
public override ValueTask SetAsync<T>(string key, T value, HybridCacheEntryOptions? options = null,
IEnumerable<string>? tags = null,
CancellationToken cancellationToken = new()) =>
throw new NotImplementedException();

public override ValueTask RemoveAsync(string key, CancellationToken cancellationToken = new()) => throw new NotImplementedException();
public override ValueTask RemoveAsync(string key, CancellationToken cancellationToken = new()) =>
throw new NotImplementedException();

public override ValueTask RemoveByTagAsync(string tag, CancellationToken cancellationToken = new()) => throw new NotImplementedException();
public override ValueTask RemoveByTagAsync(string tag, CancellationToken cancellationToken = new()) =>
throw new NotImplementedException();
}