-
Notifications
You must be signed in to change notification settings - Fork 533
TokenCredentialCache: Adds a fallback mechanism to AAD scope override. #5337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,7 +45,9 @@ internal sealed class TokenCredentialCache : IDisposable | |
| private readonly TimeSpan? userDefinedBackgroundTokenCredentialRefreshInterval; | ||
|
|
||
| private readonly SemaphoreSlim isTokenRefreshingLock = new SemaphoreSlim(1); | ||
| private readonly object backgroundRefreshLock = new object(); | ||
| private readonly object backgroundRefreshLock = new object(); | ||
| private readonly string defaultScope; | ||
| private readonly string? overrideScope; | ||
|
|
||
| private TimeSpan? systemBackgroundTokenCredentialRefreshInterval; | ||
| private Task<AccessToken>? currentRefreshOperation = null; | ||
|
|
@@ -57,21 +59,21 @@ internal TokenCredentialCache( | |
| TokenCredential tokenCredential, | ||
| Uri accountEndpoint, | ||
| TimeSpan? backgroundTokenCredentialRefreshInterval) | ||
| { | ||
| this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); | ||
|
|
||
| if (accountEndpoint == null) | ||
| { | ||
| throw new ArgumentNullException(nameof(accountEndpoint)); | ||
| { | ||
| this.tokenCredential = tokenCredential ?? throw new ArgumentNullException(nameof(tokenCredential)); | ||
| if (accountEndpoint == null) | ||
| { | ||
| throw new ArgumentNullException(nameof(accountEndpoint)); | ||
| } | ||
|
|
||
| string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null); | ||
| this.defaultScope = string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host); | ||
| this.overrideScope = !string.IsNullOrEmpty(scopeOverride) ? scopeOverride : null; | ||
|
|
||
| this.tokenRequestContext = new TokenRequestContext(new string[] | ||
| { | ||
| !string.IsNullOrEmpty(scopeOverride) | ||
| ? scopeOverride | ||
| : string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host) | ||
| { | ||
| this.overrideScope ?? this.defaultScope | ||
| }); | ||
|
|
||
| if (backgroundTokenCredentialRefreshInterval.HasValue) | ||
|
|
@@ -135,159 +137,185 @@ public void Dispose() | |
|
|
||
| private async Task<AccessToken> GetNewTokenAsync( | ||
| ITrace trace) | ||
| { | ||
| // Use a local variable to avoid the possibility the task gets changed | ||
| // between the null check and the await operation. | ||
| Task<AccessToken>? currentTask = this.currentRefreshOperation; | ||
| if (currentTask != null) | ||
| { | ||
| // The refresh is already occurring wait on the existing task | ||
| return await currentTask; | ||
| } | ||
|
|
||
| try | ||
| { | ||
| await this.isTokenRefreshingLock.WaitAsync(); | ||
|
|
||
| // avoid doing the await in the semaphore to unblock the parallel requests | ||
| if (this.currentRefreshOperation == null) | ||
| { | ||
| // ValueTask can not be awaited multiple times | ||
| currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask(); | ||
| this.currentRefreshOperation = currentTask; | ||
| } | ||
| else | ||
| { | ||
| currentTask = this.currentRefreshOperation; | ||
| } | ||
| } | ||
| finally | ||
| { | ||
| this.isTokenRefreshingLock.Release(); | ||
| } | ||
|
|
||
| return await currentTask; | ||
| } | ||
|
|
||
| { | ||
| // Use a local variable to avoid the possibility the task gets changed | ||
| // between the null check and the await operation. | ||
| Task<AccessToken>? currentTask = this.currentRefreshOperation; | ||
| if (currentTask != null) | ||
| { | ||
| // The refresh is already occurring wait on the existing task | ||
| return await currentTask; | ||
| } | ||
| try | ||
| { | ||
| await this.isTokenRefreshingLock.WaitAsync(); | ||
| // avoid doing the await in the semaphore to unblock the parallel requests | ||
| if (this.currentRefreshOperation == null) | ||
| { | ||
| // ValueTask can not be awaited multiple times | ||
| currentTask = this.RefreshCachedTokenWithRetryHelperAsync(trace).AsTask(); | ||
| this.currentRefreshOperation = currentTask; | ||
| } | ||
| else | ||
| { | ||
| currentTask = this.currentRefreshOperation; | ||
| } | ||
| } | ||
| finally | ||
| { | ||
| this.isTokenRefreshingLock.Release(); | ||
| } | ||
| return await currentTask; | ||
| } | ||
| private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync( | ||
| ITrace trace) | ||
| { | ||
| try | ||
| { | ||
| Exception? lastException = null; | ||
| const int totalRetryCount = 2; | ||
| for (int retry = 0; retry < totalRetryCount; retry++) | ||
| { | ||
| if (this.cancellationToken.IsCancellationRequested) | ||
| { | ||
| DefaultTrace.TraceInformation( | ||
| "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); | ||
|
|
||
| break; | ||
| } | ||
|
|
||
| using (ITrace getTokenTrace = trace.StartChild( | ||
| name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), | ||
| component: TraceComponent.Authorization, | ||
| level: Tracing.TraceLevel.Info)) | ||
| { | ||
| try | ||
| { | ||
| this.cachedAccessToken = await this.tokenCredential.GetTokenAsync( | ||
| requestContext: this.tokenRequestContext, | ||
| cancellationToken: this.cancellationToken); | ||
|
|
||
| if (!this.cachedAccessToken.HasValue) | ||
| { | ||
| throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); | ||
| } | ||
|
|
||
| if (this.cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) | ||
| { | ||
| throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}"); | ||
| } | ||
|
|
||
| if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) | ||
| { | ||
| double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; | ||
|
|
||
| // Ensure the background refresh interval is a valid range. | ||
| refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); | ||
| refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); | ||
| this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); | ||
| } | ||
|
|
||
| return this.cachedAccessToken.Value; | ||
| } | ||
| catch (RequestFailedException requestFailedException) | ||
| { | ||
| lastException = requestFailedException; | ||
| getTokenTrace.AddDatum( | ||
| $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| requestFailedException.Message); | ||
|
|
||
| DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); | ||
|
|
||
| // Don't retry on auth failures | ||
| if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || | ||
| requestFailedException.Status == (int)HttpStatusCode.Forbidden) | ||
| { | ||
| this.cachedAccessToken = default; | ||
| throw; | ||
| } | ||
| } | ||
| catch (OperationCanceledException operationCancelled) | ||
| { | ||
| lastException = operationCancelled; | ||
| getTokenTrace.AddDatum( | ||
| $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| operationCancelled.Message); | ||
|
|
||
| DefaultTrace.TraceError( | ||
| $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); | ||
|
|
||
| throw CosmosExceptionFactory.CreateRequestTimeoutException( | ||
| message: ClientResources.FailedToGetAadToken, | ||
| headers: new Headers() | ||
| { | ||
| SubStatusCode = SubStatusCodes.FailedToGetAadToken, | ||
| }, | ||
| innerException: lastException, | ||
| trace: getTokenTrace); | ||
| } | ||
| catch (Exception exception) | ||
| { | ||
| lastException = exception; | ||
| getTokenTrace.AddDatum( | ||
| $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| exception.Message); | ||
|
|
||
| DefaultTrace.TraceError( | ||
| $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (lastException == null) | ||
| { | ||
| throw new ArgumentException("Last exception is null."); | ||
| } | ||
|
|
||
| // The retries have been exhausted. Throw the last exception. | ||
| throw lastException; | ||
| } | ||
| finally | ||
| { | ||
| try | ||
| { | ||
| await this.isTokenRefreshingLock.WaitAsync(); | ||
| this.currentRefreshOperation = null; | ||
| } | ||
| finally | ||
| { | ||
| this.isTokenRefreshingLock.Release(); | ||
| } | ||
| } | ||
| { | ||
| try | ||
| { | ||
| Exception? lastException = null; | ||
| const int totalRetryCount = 2; | ||
| for (int retry = 0; retry < totalRetryCount; retry++) | ||
| { | ||
| if (this.cancellationToken.IsCancellationRequested) | ||
| { | ||
| DefaultTrace.TraceInformation( | ||
| "Stop RefreshTokenWithIndefiniteRetries because cancellation is requested"); | ||
|
|
||
| break; | ||
| } | ||
|
|
||
| using (ITrace getTokenTrace = trace.StartChild( | ||
| name: nameof(this.RefreshCachedTokenWithRetryHelperAsync), | ||
| component: TraceComponent.Authorization, | ||
| level: Tracing.TraceLevel.Info)) | ||
| { | ||
| try | ||
| { | ||
| if (this.overrideScope != null) | ||
| { | ||
| try | ||
| { | ||
| TokenRequestContext overrideContext = new TokenRequestContext(new string[] { this.overrideScope }); | ||
| this.cachedAccessToken = await this.GetAndValidateTokenAsync(overrideContext); | ||
| } | ||
| catch (Exception ex) | ||
| { | ||
| DefaultTrace.TraceError($"TokenCredential.GetTokenAsync failed with override scope '{this.overrideScope}': {ex.Message}. Retrying with default scope '{this.defaultScope}'."); | ||
|
|
||
| TokenRequestContext defaultContext = new TokenRequestContext(new string[] { this.defaultScope }); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please share what's the behavior if multiple scopes are specified?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ex: What happens if both account and generic ones are included. |
||
| this.cachedAccessToken = await this.GetAndValidateTokenAsync(defaultContext); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Orrides are final right? |
||
| } | ||
| } | ||
| else | ||
| { | ||
| TokenRequestContext defaultContext = new TokenRequestContext(new string[] { this.defaultScope }); | ||
| this.cachedAccessToken = await this.GetAndValidateTokenAsync(defaultContext); | ||
| } | ||
|
|
||
| if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue) | ||
| { | ||
| double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage; | ||
|
|
||
| // Ensure the background refresh interval is a valid range. | ||
| refreshIntervalInSeconds = Math.Max(refreshIntervalInSeconds, TokenCredentialCache.MinimumTimeBetweenBackgroundRefreshInterval.TotalSeconds); | ||
| refreshIntervalInSeconds = Math.Min(refreshIntervalInSeconds, TokenCredentialCache.MaxBackgroundRefreshInterval.TotalSeconds); | ||
| this.systemBackgroundTokenCredentialRefreshInterval = TimeSpan.FromSeconds(refreshIntervalInSeconds); | ||
| } | ||
|
|
||
| return this.cachedAccessToken.Value; | ||
| } | ||
| catch (RequestFailedException requestFailedException) | ||
| { | ||
| lastException = requestFailedException; | ||
| getTokenTrace.AddDatum( | ||
| $"RequestFailedException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| requestFailedException.Message); | ||
|
|
||
| DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}"); | ||
|
|
||
| // Don't retry on auth failures | ||
| if (requestFailedException.Status == (int)HttpStatusCode.Unauthorized || | ||
| requestFailedException.Status == (int)HttpStatusCode.Forbidden) | ||
| { | ||
| this.cachedAccessToken = default; | ||
| throw; | ||
| } | ||
| } | ||
| catch (OperationCanceledException operationCancelled) | ||
| { | ||
| lastException = operationCancelled; | ||
| getTokenTrace.AddDatum( | ||
| $"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| operationCancelled.Message); | ||
|
|
||
| DefaultTrace.TraceError( | ||
| $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}"); | ||
|
|
||
| throw CosmosExceptionFactory.CreateRequestTimeoutException( | ||
| message: ClientResources.FailedToGetAadToken, | ||
| headers: new Headers() | ||
| { | ||
| SubStatusCode = SubStatusCodes.FailedToGetAadToken, | ||
| }, | ||
| innerException: lastException, | ||
| trace: getTokenTrace); | ||
| } | ||
| catch (Exception exception) | ||
| { | ||
| lastException = exception; | ||
| getTokenTrace.AddDatum( | ||
| $"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}", | ||
| exception.Message); | ||
|
|
||
| DefaultTrace.TraceError( | ||
| $"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", this.overrideScope ?? this.defaultScope)}, retry = {retry}, Exception = {lastException.Message}"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (lastException == null) | ||
| { | ||
| throw new ArgumentException("Last exception is null."); | ||
| } | ||
|
|
||
| // The retries have been exhausted. Throw the last exception. | ||
| throw lastException; | ||
| } | ||
| finally | ||
| { | ||
| try | ||
| { | ||
| await this.isTokenRefreshingLock.WaitAsync(); | ||
| this.currentRefreshOperation = null; | ||
| } | ||
| finally | ||
| { | ||
| this.isTokenRefreshingLock.Release(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private async Task<AccessToken> GetAndValidateTokenAsync(TokenRequestContext requestContext) | ||
| { | ||
| AccessToken? cachedAccessToken = await this.tokenCredential.GetTokenAsync( | ||
| requestContext: requestContext, | ||
| cancellationToken: this.cancellationToken); | ||
|
|
||
| if (!cachedAccessToken.HasValue) | ||
| { | ||
| throw new ArgumentNullException("TokenCredential.GetTokenAsync returned a null token."); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Include |
||
| } | ||
|
|
||
| if (cachedAccessToken.Value.ExpiresOn < DateTimeOffset.UtcNow) | ||
| { | ||
| throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{cachedAccessToken.Value.ExpiresOn:O}"); | ||
| } | ||
|
|
||
| return cachedAccessToken.Value; | ||
| } | ||
|
|
||
| #pragma warning disable VSTHRD100 // Avoid async void methods | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrapping it with our exception types gives us flexibility to include extra context like scopes.
With side-affect of mis-interpretting it as Cosmos issue. Thoughs?
/cc: @FabianMeiswinkel