Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 190 additions & 162 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ internal sealed class TokenCredentialCache : IDisposable
private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval;

private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1);
private readonly object backgroundRefreshLock = new object();
private readonly object backgroundRefreshLock = new object();
private readonly string defaultScope;
private readonly string? overrideScope;

private TimeSpan? systemBackgroundTokenCredentialRefreshInterval;
private Task<AccessToken>? currentRefreshOperation = null;
Expand All @@ -57,21 +59,21 @@ internal TokenCredentialCache(
TokenCredential tokenCredential,
Uri accountEndpoint,
TimeSpan? backgroundTokenCredentialRefreshInterval)
{
this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential));

if (accountEndpoint == null)
{
throw new ArgumentNullException(nameof(accountEndpoint));
{
this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential));
if (accountEndpoint == null)
{
throw new ArgumentNullException(nameof(accountEndpoint));
}

string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);
this.defaultScope = string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host);
this.overrideScope = !string.IsNullOrEmpty(scopeOverride) ? scopeOverride : null;

this.tokenRequestContext = new TokenRequestContext(new string[]
{
!string.IsNullOrEmpty(scopeOverride)
? scopeOverride
: string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
{
this.overrideScope ?? this.defaultScope
});

if (backgroundTokenCredentialRefreshInterval.HasValue)
Expand Down Expand Up @@ -135,159 +137,185 @@ public void Dispose()

private async Task<AccessToken> GetNewTokenAsync(
ITrace trace)
{
// Use a local variable to avoid the possibility the task gets changed
// between the null check and the await operation.
Task<AccessToken>? currentTask = this.currentRefreshOperation;
if (currentTask != null)
{
// The refresh is already occurring wait on the existing task
return await currentTask;
}

try
{
await this.isTokenRefreshingLock.WaitAsync();

// avoid doing the await in the semaphore to unblock the parallel requests
if (this.currentRefreshOperation == null)
{
// ValueTask can not be awaited multiple times
currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask();
this.currentRefreshOperation = currentTask;
}
else
{
currentTask = this.currentRefreshOperation;
}
}
finally
{
this.isTokenRefreshingLock.Release();
}

return await currentTask;
}

{
// Use a local variable to avoid the possibility the task gets changed
// between the null check and the await operation.
Task<AccessToken>? currentTask = this.currentRefreshOperation;
if (currentTask != null)
{
// The refresh is already occurring wait on the existing task
return await currentTask;
}
try
{
await this.isTokenRefreshingLock.WaitAsync();
// avoid doing the await in the semaphore to unblock the parallel requests
if (this.currentRefreshOperation == null)
{
// ValueTask can not be awaited multiple times
currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask();
this.currentRefreshOperation = currentTask;
}
else
{
currentTask = this.currentRefreshOperation;
}
}
finally
{
this.isTokenRefreshingLock.Release();
}
return await currentTask;
}
private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
ITrace trace)
{
try
{
Exception? lastException = null;
const int totalRetryCount = 2;
for (int retry = 0; retry < totalRetryCount; retry++)
{
if (this.cancellationToken.IsCancellationRequested)
{
DefaultTrace.TraceInformation(
"Stop RefreshTokenWithIndefiniteRetries because cancellation is requested");

break;
}

using (ITrace getTokenTrace = trace.StartChild(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
try
{
this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: this.tokenRequestContext,
cancellationToken: this.cancellationToken);

if (!this.cachedAccessToken.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
}

if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}");
}

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

// Ensure the background refresh interval is a valid range.
refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}

return this.cachedAccessToken.Value;
}
catch (RequestFailedException requestFailedException)
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelled)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
headers: new Headers()
{
SubStatusCode = SubStatusCodes.FailedToGetAadToken,
},
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
}
}
}

if (lastException == null)
{
throw new ArgumentException("Last exception is null.");
}

// The retries have been exhausted. Throw the last exception.
throw lastException;
}
finally
{
try
{
await this.isTokenRefreshingLock.WaitAsync();
this.currentRefreshOperation = null;
}
finally
{
this.isTokenRefreshingLock.Release();
}
}
{
try
{
Exception? lastException = null;
const int totalRetryCount = 2;
for (int retry = 0; retry < totalRetryCount; retry++)
{
if (this.cancellationToken.IsCancellationRequested)
{
DefaultTrace.TraceInformation(
"Stop RefreshTokenWithIndefiniteRetries because cancellation is requested");

break;
}

using (ITrace getTokenTrace = trace.StartChild(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
try
{
if (this.overrideScope != null)
{
try
{
TokenRequestContext overrideContext = new TokenRequestContext(new string[] { this.overrideScope });
this.cachedAccessToken = await this.GetAndValidateTokenAsync(overrideContext);
}
catch (Exception ex)
{
DefaultTrace.TraceError($"TokenCredential.GetTokenAsync failed with override scope '{this.overrideScope}': {ex.Message}. Retrying with default scope '{this.defaultScope}'.");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping it with our exception types gives us flexibility to include extra context like scopes.
With side-affect of mis-interpretting it as Cosmos issue. Thoughs?

/cc: @FabianMeiswinkel


TokenRequestContext defaultContext = new TokenRequestContext(new string[] { this.defaultScope });
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please share what's the behavior if multiple scopes are specified?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ex: What happens if both account and generic ones are included.

this.cachedAccessToken = await this.GetAndValidateTokenAsync(defaultContext);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Orrides are final right?

}
}
else
{
TokenRequestContext defaultContext = new TokenRequestContext(new string[] { this.defaultScope });
this.cachedAccessToken = await this.GetAndValidateTokenAsync(defaultContext);
}

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;

// Ensure the background refresh interval is a valid range.
refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds);
refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds);
this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds);
}

return this.cachedAccessToken.Value;
}
catch (RequestFailedException requestFailedException)
{
lastException = requestFailedException;
getTokenTrace.AddDatum(
$"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
requestFailedException.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden)
{
this.cachedAccessToken = default;
throw;
}
}
catch (OperationCanceledException operationCancelled)
{
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
headers: new Headers()
{
SubStatusCode = SubStatusCodes.FailedToGetAadToken,
},
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}");
}
}
}

if (lastException == null)
{
throw new ArgumentException("Last exception is null.");
}

// The retries have been exhausted. Throw the last exception.
throw lastException;
}
finally
{
try
{
await this.isTokenRefreshingLock.WaitAsync();
this.currentRefreshOperation = null;
}
finally
{
this.isTokenRefreshingLock.Release();
}
}
}

private async Task<AccessToken> GetAndValidateTokenAsync(TokenRequestContext requestContext)
{
AccessToken? cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: requestContext,
cancellationToken: this.cancellationToken);

if (!cachedAccessToken.HasValue)
{
throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token.");
Copy link
Copy Markdown
Member

@kirankumarkolli kirankumarkolli Aug 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include Scopes in all exception message.

}

if (cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow)
{
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{cachedAccessToken.Value.ExpiresOn:O}");
}

return cachedAccessToken.Value;
}

#pragma warning disable VSTHRD100 // Avoid async void methods
Expand Down
Loading