diff --git a/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs b/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs index 3e81c62d0d..b0af7f1c1d 100644 --- a/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs +++ b/Microsoft.Azure.Cosmos/src/ClientRetryPolicy.cs @@ -23,6 +23,7 @@ internal sealed class ClientRetryPolicy : IDocumentClientRetryPolicy private const int RetryIntervalInMS = 1000; // Once we detect failover wait for 1 second before retrying request. private const int MaxRetryCount = 120; private const int MaxServiceUnavailableRetryCount = 1; + private const int MaxDtxRetryCount = 100; // DTX commits carry an idempotency token, making them safe to retry regardless of account topology private readonly IDocumentClientRetryPolicy throttlingRetry; private readonly GlobalEndpointManager globalEndpointManager; @@ -33,6 +34,7 @@ internal sealed class ClientRetryPolicy : IDocumentClientRetryPolicy private int sessionTokenRetryCount; private int serviceUnavailableRetryCount; + private int distributedTransactionRetryCount; private bool isReadRequest; private bool canUseMultipleWriteLocations; private bool isMultiMasterWriteRequest; @@ -117,7 +119,8 @@ public async Task ShouldRetryAsync( ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync( clientException?.StatusCode, - clientException?.GetSubStatus()); + clientException?.GetSubStatus(), + clientException?.RetryAfter); if (shouldRetryResult != null) { return shouldRetryResult; @@ -131,7 +134,8 @@ public async Task ShouldRetryAsync( { ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync( cosmosException.StatusCode, - cosmosException.Headers.SubStatusCode); + cosmosException.Headers.SubStatusCode, + cosmosException.RetryAfter); if (shouldRetryResult != null) { return shouldRetryResult; @@ -172,7 +176,8 @@ public async Task ShouldRetryAsync( ShouldRetryResult shouldRetryResult = await this.ShouldRetryInternalAsync( cosmosResponseMessage?.StatusCode, - cosmosResponseMessage?.Headers.SubStatusCode); + cosmosResponseMessage?.Headers.SubStatusCode, + cosmosResponseMessage?.Headers.RetryAfter); if (shouldRetryResult != null) { return shouldRetryResult; @@ -245,7 +250,8 @@ public void OnBeforeSendRequest(DocumentServiceRequest request) private async Task ShouldRetryInternalAsync( HttpStatusCode? statusCode, - SubStatusCodes? subStatusCode) + SubStatusCodes? subStatusCode, + TimeSpan? retryAfter = null) { if (!statusCode.HasValue && (!subStatusCode.HasValue @@ -356,6 +362,55 @@ private async Task ShouldRetryInternalAsync( return this.ShouldRetryOnUnavailableEndpointStatusCodes(); } + // DTX-specific retriable codes. DTX commits carry an idempotency token making them safe + // to retry regardless of account topology (single-master, multi-master, single-region). + // 403.3 WriteForbidden and 429 throttling are already handled for all request types above. + if (this.documentServiceRequest != null + && DistributedTransactionConstants.IsDistributedTransactionRequest( + this.documentServiceRequest.OperationType, + this.documentServiceRequest.ResourceType)) + { + TimeSpan? dtxRetryDelay = null; + + // 408 RequestTimeout: endpoint already marked unavailable above; retry (idempotency ensures safety). + if (statusCode == HttpStatusCode.RequestTimeout) + { + dtxRetryDelay = TimeSpan.Zero; + } + // 449/5352: coordinator race conflict — retry after server-specified backoff. + else if ((int?)statusCode == (int)StatusCodes.RetryWith + && subStatusCode == (SubStatusCodes)DistributedTransactionConstants.DtcCoordinatorRaceConflict) + { + dtxRetryDelay = retryAfter ?? TimeSpan.Zero; + } + // 500/5411-5413: transient infrastructure failures — safe to retry for writes (idempotency guaranteed). + else if (statusCode == HttpStatusCode.InternalServerError + && (subStatusCode == (SubStatusCodes)DistributedTransactionConstants.DtcLedgerFailure + || subStatusCode == (SubStatusCodes)DistributedTransactionConstants.DtcAccountConfigFailure + || subStatusCode == (SubStatusCodes)DistributedTransactionConstants.DtcDispatchFailure)) + { + dtxRetryDelay = TimeSpan.Zero; + } + + if (dtxRetryDelay.HasValue) + { + if (this.distributedTransactionRetryCount++ >= ClientRetryPolicy.MaxDtxRetryCount) + { + DefaultTrace.TraceInformation("ClientRetryPolicy: DTX retry budget exhausted. distributedTransactionRetryCount={0}, StatusCode={1}, SubStatusCode={2}.", + this.distributedTransactionRetryCount, statusCode, subStatusCode); + return ShouldRetryResult.NoRetry(); + } + + DefaultTrace.TraceWarning("ClientRetryPolicy: DTX retriable response (StatusCode={0}, SubStatusCode={1}, attempt={2}). Retrying. Failed Location: {3}", + statusCode, + subStatusCode, + this.distributedTransactionRetryCount, + this.documentServiceRequest?.RequestContext?.LocationEndpointToRoute?.ToString() ?? string.Empty); + + return ShouldRetryResult.RetryAfter(dtxRetryDelay.Value); + } + } + return null; } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs index 46654f5e2f..02d113fff9 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionCommitter.cs @@ -1,4 +1,4 @@ -// ------------------------------------------------------------ +// ------------------------------------------------------------ // Copyright (c) Microsoft Corporation. All rights reserved. // ------------------------------------------------------------ @@ -17,15 +17,33 @@ namespace Microsoft.Azure.Cosmos internal class DistributedTransactionCommitter { + private static readonly TimeSpan DefaultRetryBaseDelay = TimeSpan.FromSeconds(1); + + internal const int MaxIsRetriableRetryCount = 100; + private readonly IReadOnlyList operations; private readonly CosmosClientContext clientContext; + private readonly TimeSpan retryBaseDelay; + private readonly Random jitter = new Random(); + private readonly Func delayProvider; public DistributedTransactionCommitter( IReadOnlyList operations, CosmosClientContext clientContext) + : this(operations, clientContext, DefaultRetryBaseDelay) + { + } + + internal DistributedTransactionCommitter( + IReadOnlyList operations, + CosmosClientContext clientContext, + TimeSpan retryBaseDelay, + Func delayProvider = null) { this.operations = operations ?? throw new ArgumentNullException(nameof(operations)); this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext)); + this.retryBaseDelay = retryBaseDelay; + this.delayProvider = delayProvider ?? Task.Delay; } public async Task CommitTransactionAsync(CancellationToken cancellationToken) @@ -43,9 +61,9 @@ await DistributedTransactionCommitterUtils.ResolveCollectionRidsAsync( this.clientContext.SerializerCore, cancellationToken); - return await this.ExecuteCommitAsync(serverRequest, cancellationToken); + return await this.ExecuteCommitWithRetryAsync(serverRequest, cancellationToken); } - catch (Exception ex) + catch (Exception ex) when (ex is not OperationCanceledException) { DefaultTrace.TraceError($"Distributed transaction failed: {ex.Message}"); // await this.AbortTransactionAsync(cancellationToken); @@ -53,14 +71,62 @@ await DistributedTransactionCommitterUtils.ResolveCollectionRidsAsync( } } + private async Task ExecuteCommitWithRetryAsync( + DistributedTransactionServerRequest serverRequest, + CancellationToken cancellationToken) + { + int attempt = 0; + using (ITrace retryTrace = Trace.GetRootTrace("Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info)) + { + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + DistributedTransactionResponse response = await this.ExecuteCommitAsync(serverRequest, retryTrace, cancellationToken); + + if (!response.IsSuccessStatusCode && response.IsRetriable) + { + if (attempt >= DistributedTransactionCommitter.MaxIsRetriableRetryCount) + { + DefaultTrace.TraceWarning( + $"Distributed transaction isRetriable retry budget exhausted after {attempt} attempts " + + $"(StatusCode={response.StatusCode}). Returning last response."); + return response; + } + + DefaultTrace.TraceWarning( + $"Distributed transaction commit retriable (StatusCode={response.StatusCode}, " + + $"IsRetriable={response.IsRetriable}, attempt {attempt + 1}). " + + $"Retrying with idempotency token {serverRequest.IdempotencyToken}."); + response.Dispose(); + await this.delayProvider(this.GetRetryDelay(attempt++), cancellationToken); + continue; + } + + return response; + } + } + } + + private TimeSpan GetRetryDelay(int attempt) + { + const int maxExponent = 5; + int exponent = Math.Min(attempt, maxExponent); + double baseDelayMs = this.retryBaseDelay.TotalMilliseconds * Math.Pow(2, exponent); + // Jitter: uniform random to decorrelate concurrent clients and avoid synchronized retry storms. + double jitterDelay = baseDelayMs * this.jitter.NextDouble(); + return TimeSpan.FromMilliseconds((baseDelayMs * 0.5) + jitterDelay); + } + private async Task ExecuteCommitAsync( DistributedTransactionServerRequest serverRequest, + ITrace parentTrace, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (ITrace trace = Trace.GetRootTrace("Execute Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info)) + using (ITrace attemptTrace = parentTrace.StartChild("Execute Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info)) { - using (MemoryStream bodyStream = serverRequest.TransferBodyStream()) + using (MemoryStream bodyStream = serverRequest.CreateBodyStream()) { ResponseMessage responseMessage = await this.clientContext.ProcessResourceOperationStreamAsync( resourceUri: DistributedTransactionCommitter.GetResourceUri(), @@ -72,25 +138,25 @@ private async Task ExecuteCommitAsync( itemId: null, streamPayload: bodyStream, requestEnricher: requestMessage => DistributedTransactionCommitter.EnrichRequestMessage(requestMessage, serverRequest), - trace: trace, + trace: attemptTrace, cancellationToken: cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); - - DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( - responseMessage, - serverRequest, - this.clientContext.SerializerCore, - serverRequest.IdempotencyToken, - trace, - cancellationToken); - - DistributedTransactionCommitter.MergeSessionTokens( - response, - serverRequest, - this.clientContext.DocumentClient.sessionContainer); - - return response; + using (responseMessage) + { + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + this.clientContext.SerializerCore, + parentTrace, + cancellationToken); + + DistributedTransactionCommitter.MergeSessionTokens( + response, + serverRequest, + this.clientContext.DocumentClient?.sessionContainer); + + return response; + } } } } diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs index 81896c0b8e..20a7dc0dc1 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionConstants.cs @@ -8,7 +8,25 @@ namespace Microsoft.Azure.Cosmos internal static class DistributedTransactionConstants { - public static bool IsDistributedTransactionRequest(OperationType operationType, ResourceType resourceType) + // Sub-status codes returned on the envelope response for distributed transactions. + // Source: dtx-sdk-response-status-codes.md — Part A, Section 1. + + /// 449/5352 — Coordinator race conflict (ETag contention on the ledger exhausted). + internal const int DtcCoordinatorRaceConflict = 5352; + + /// 429/3200 — Ledger RU throttled and coordinator exhausted its internal retry budget. + internal const int DtcLedgerThrottled = 3200; + + /// 500/5411 — Ledger infrastructure failure. + internal const int DtcLedgerFailure = 5411; + + /// 500/5412 — Account configuration failure. + internal const int DtcAccountConfigFailure = 5412; + + /// 500/5413 — Coordinator dispatch failure. + internal const int DtcDispatchFailure = 5413; + + internal static bool IsDistributedTransactionRequest(OperationType operationType, ResourceType resourceType) { return operationType == OperationType.CommitDistributedTransaction && resourceType == ResourceType.DistributedTransactionBatch; diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs index 6e32f3c9f6..cbf0e2c4cb 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionResponse.cs @@ -38,7 +38,7 @@ private DistributedTransactionResponse( CosmosSerializerCore serializer, ITrace trace, Guid idempotencyToken, - string serverDiagnostics = null) + bool isRetriable = false) { this.Headers = headers; this.StatusCode = statusCode; @@ -48,7 +48,7 @@ private DistributedTransactionResponse( this.SerializerCore = serializer; this.Trace = trace; this.IdempotencyToken = idempotencyToken; - this.ServerDiagnostics = serverDiagnostics; + this.IsRetriable = isRetriable; } /// @@ -111,7 +111,14 @@ public virtual DistributedTransactionOperationResult this[int index] /// /// Gets the number of operation results in the distributed transaction response. /// - public virtual int Count => this.results?.Count ?? 0; + public virtual int Count + { + get + { + this.ThrowIfDisposed(); + return this.results?.Count ?? 0; + } + } /// /// Gets the idempotency token associated with this distributed transaction. @@ -119,9 +126,9 @@ public virtual DistributedTransactionOperationResult this[int index] public virtual Guid IdempotencyToken { get; } /// - /// Gets the server-side diagnostic information for the transaction. + /// Gets a value indicating whether the transaction is safe to retry with the same idempotency token. /// - public virtual string ServerDiagnostics { get; } + public virtual bool IsRetriable { get; } internal virtual SubStatusCodes SubStatusCode { get; } @@ -137,6 +144,7 @@ public virtual DistributedTransactionOperationResult this[int index] /// An enumerator for the operation results. public virtual IEnumerator GetEnumerator() { + this.ThrowIfDisposed(); return this.results?.GetEnumerator() ?? ((IList)Array.Empty()).GetEnumerator(); } @@ -166,7 +174,28 @@ internal static async Task FromResponseMessageAs ResponseMessage responseMessage, DistributedTransactionServerRequest serverRequest, CosmosSerializerCore serializer, - Guid requestIdempotencyToken, + Guid idempotencyToken, + ITrace trace, + CancellationToken cancellationToken) + { + using (ITrace createResponseTrace = trace.StartChild("Create Distributed Transaction Response", TraceComponent.Batch, TraceLevel.Info)) + { + cancellationToken.ThrowIfCancellationRequested(); + + return await DistributedTransactionResponse.FromResponseMessageCoreAsync( + responseMessage, + serverRequest, + serializer, + idempotencyToken, + createResponseTrace, + cancellationToken); + } + } + + internal static async Task FromResponseMessageAsync( + ResponseMessage responseMessage, + DistributedTransactionServerRequest serverRequest, + CosmosSerializerCore serializer, ITrace trace, CancellationToken cancellationToken) { @@ -175,77 +204,94 @@ internal static async Task FromResponseMessageAs cancellationToken.ThrowIfCancellationRequested(); // Extract idempotency token from response headers, fallback to request token if not present - Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, requestIdempotencyToken); + Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, serverRequest.IdempotencyToken); + + return await DistributedTransactionResponse.FromResponseMessageCoreAsync( + responseMessage, + serverRequest, + serializer, + idempotencyToken, + createResponseTrace, + cancellationToken); + } + } - DistributedTransactionResponse response = null; - MemoryStream memoryStream = null; + private static async Task FromResponseMessageCoreAsync( + ResponseMessage responseMessage, + DistributedTransactionServerRequest serverRequest, + CosmosSerializerCore serializer, + Guid idempotencyToken, + ITrace trace, + CancellationToken cancellationToken) + { + DistributedTransactionResponse response = null; + MemoryStream memoryStream = null; - try + try + { + if (responseMessage.Content != null) { - if (responseMessage.Content != null) - { - Stream content = responseMessage.Content; + Stream content = responseMessage.Content; - // Ensure the stream is seekable - if (!content.CanSeek) - { - memoryStream = new MemoryStream(); - await responseMessage.Content.CopyToAsync(memoryStream); - memoryStream.Position = 0; - content = memoryStream; - } - - response = await PopulateFromJsonContentAsync( - content, - responseMessage, - serverRequest, - serializer, - idempotencyToken, - createResponseTrace, - cancellationToken); + // Ensure the stream is seekable + if (!content.CanSeek) + { + memoryStream = new MemoryStream(); + await responseMessage.Content.CopyToAsync(memoryStream); + memoryStream.Position = 0; + content = memoryStream; } - // If we couldn't parse JSON content or there was no content, create default response - response ??= new DistributedTransactionResponse( - responseMessage.StatusCode, - responseMessage.Headers.SubStatusCode, - responseMessage.ErrorMessage, - responseMessage.Headers, - serverRequest.Operations, + response = await PopulateFromJsonContentAsync( + content, + responseMessage, + serverRequest, serializer, - createResponseTrace, - idempotencyToken); - - // Validate results count matches operations count - if (response.results == null || response.results.Count != serverRequest.Operations.Count) - { - DefaultTrace.TraceWarning( - $"DTC response: result count ({response.results?.Count ?? 0}) differs from " + - $"operation count ({serverRequest.Operations.Count})."); + idempotencyToken, + trace, + cancellationToken); + } - if (responseMessage.IsSuccessStatusCode) - { - // Server should guarantee results count equals operations count on success - return new DistributedTransactionResponse( - HttpStatusCode.InternalServerError, - SubStatusCodes.Unknown, - ClientResources.InvalidServerResponse, - responseMessage.Headers, - serverRequest.Operations, - serializer, - createResponseTrace, - idempotencyToken); - } + // If we couldn't parse JSON content or there was no content, create default response + response ??= new DistributedTransactionResponse( + responseMessage.StatusCode, + responseMessage.Headers.SubStatusCode, + responseMessage.ErrorMessage, + responseMessage.Headers, + serverRequest.Operations, + serializer, + trace, + idempotencyToken); + + // Validate results count matches operations count + if (response.results == null || response.results.Count != serverRequest.Operations.Count) + { + DefaultTrace.TraceWarning( + $"DTC response: result count ({response.results?.Count ?? 0}) differs from " + + $"operation count ({serverRequest.Operations.Count})."); - response.CreateAndPopulateResults(serverRequest.Operations, createResponseTrace); + if (responseMessage.IsSuccessStatusCode) + { + // Server should guarantee results count equals operations count on success + return new DistributedTransactionResponse( + HttpStatusCode.InternalServerError, + SubStatusCodes.Unknown, + ClientResources.InvalidServerResponse, + responseMessage.Headers, + serverRequest.Operations, + serializer, + trace, + idempotencyToken); } - return response; - } - finally - { - memoryStream?.Dispose(); + response.CreateAndPopulateResults(serverRequest.Operations, trace); } + + return response; + } + finally + { + memoryStream?.Dispose(); } } @@ -302,16 +348,34 @@ private static async Task PopulateFromJsonConten CancellationToken cancellationToken) { List results = new List(); + bool isRetriable = false; + JsonDocument responseJson; try { - using (JsonDocument responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken)) + responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken); + } + catch (JsonException) + { + // Unparseable body — fall back to default response construction. + return null; + } + + using (responseJson) + { + JsonElement root = responseJson.RootElement; + + if (root.TryGetProperty("isRetriable", out JsonElement isRetriableElement) && + isRetriableElement.ValueKind == JsonValueKind.True) { - JsonElement root = responseJson.RootElement; + isRetriable = true; + } - // Parse operation results from "operationResponses" array - if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) && - operationResponses.ValueKind == JsonValueKind.Array) + // Parse operation results from "operationResponses" array. + if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) && + operationResponses.ValueKind == JsonValueKind.Array) + { + try { foreach (JsonElement operationElement in operationResponses.EnumerateArray()) { @@ -323,13 +387,12 @@ private static async Task PopulateFromJsonConten results.Add(operationResult); } } + catch (JsonException) + { + results.Clear(); + } } } - catch (JsonException) - { - // If JSON parsing fails, return null to fall back to default response - return null; - } HttpStatusCode finalStatusCode = responseMessage.StatusCode; SubStatusCodes finalSubStatusCode = responseMessage.Headers.SubStatusCode; @@ -357,7 +420,8 @@ private static async Task PopulateFromJsonConten serverRequest.Operations, serializer, trace, - idempotencyToken) + idempotencyToken, + isRetriable) { results = results }; diff --git a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs index 6cbab1e5f8..50f8c9f314 100644 --- a/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs +++ b/Microsoft.Azure.Cosmos/src/DistributedTransaction/DistributedTransactionServerRequest.cs @@ -13,7 +13,7 @@ namespace Microsoft.Azure.Cosmos internal class DistributedTransactionServerRequest { private readonly CosmosSerializerCore serializerCore; - private MemoryStream bodyStream; + private byte[] serializedBody; private DistributedTransactionServerRequest( IReadOnlyList operations, @@ -26,7 +26,7 @@ private DistributedTransactionServerRequest( public IReadOnlyList Operations { get; } - public Guid IdempotencyToken { get; private set; } + public Guid IdempotencyToken { get; } public static async Task CreateAsync( IReadOnlyList operations, @@ -38,11 +38,21 @@ public static async Task CreateAsync( return request; } - public MemoryStream TransferBodyStream() + /// + /// Returns a new backed by the pre-serialized request bytes. + /// Each call returns an independent, non-writable stream positioned at offset zero so + /// that the caller can safely wrap it in a using block and dispose it without + /// affecting subsequent retry attempts. + /// + /// Body stream. + public MemoryStream CreateBodyStream() { - MemoryStream bodyStream = this.bodyStream; - this.bodyStream = null; - return bodyStream; + if (this.serializedBody == null) + { + throw new InvalidOperationException("Request body has not been initialized. Use CreateAsync to construct a request."); + } + + return new MemoryStream(this.serializedBody, writable: false); } private async Task CreateBodyStreamAsync(CancellationToken cancellationToken) @@ -53,7 +63,10 @@ private async Task CreateBodyStreamAsync(CancellationToken cancellationToken) operation.PartitionKeyJson ??= operation.PartitionKey.ToJsonString(); } - this.bodyStream = DistributedTransactionSerializer.SerializeRequest(this.Operations); + using (MemoryStream stream = DistributedTransactionSerializer.SerializeRequest(this.Operations)) + { + this.serializedBody = stream.ToArray(); + } } } } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs index b595711070..29620c1bde 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/ClientRetryPolicyTests.cs @@ -642,8 +642,186 @@ await BackoffRetryUtility.ExecuteAsync( } } } - } - + } + + // ─── DTX (Distributed Transaction) retry tests ─────────────────────────────── + + [TestMethod] + public async Task DtxRequest_408_ShouldRetry() + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + DocumentServiceRequest request = ClientRetryPolicyTests.CreateDtxRequest(); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage(HttpStatusCode.RequestTimeout); + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsTrue(result.ShouldRetry, "DTX 408 must be retried — idempotency token guarantees safety."); + } + + [TestMethod] + public async Task DtxRequest_449_5352_ShouldRetry_WithZeroDelay() + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + DocumentServiceRequest request = ClientRetryPolicyTests.CreateDtxRequest(); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage((HttpStatusCode)StatusCodes.RetryWith); + response.Headers.SubStatusCodeLiteral = DistributedTransactionConstants.DtcCoordinatorRaceConflict.ToString(); + + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsTrue(result.ShouldRetry, "DTX 449/5352 coordinator race conflict must be retried."); + Assert.AreEqual(TimeSpan.Zero, result.BackoffTime, "When no Retry-After header is present, delay should be zero."); + } + + [TestMethod] + public async Task DtxRequest_449_5352_ShouldRetry_HonorsRetryAfterHeader() + { + const bool enableEndpointDiscovery = true; + TimeSpan serverRetryAfter = TimeSpan.FromMilliseconds(250); + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + DocumentServiceRequest request = ClientRetryPolicyTests.CreateDtxRequest(); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage((HttpStatusCode)StatusCodes.RetryWith); + response.Headers.SubStatusCodeLiteral = DistributedTransactionConstants.DtcCoordinatorRaceConflict.ToString(); + response.Headers.RetryAfterLiteral = ((long)serverRetryAfter.TotalMilliseconds).ToString(); + + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsTrue(result.ShouldRetry, "DTX 449/5352 must be retried."); + Assert.AreEqual(serverRetryAfter, result.BackoffTime, "Retry delay must honor the server's Retry-After header."); + } + + [DataTestMethod] + [DataRow(DistributedTransactionConstants.DtcLedgerFailure, DisplayName = "500/5411 LedgerFailure")] + [DataRow(DistributedTransactionConstants.DtcAccountConfigFailure, DisplayName = "500/5412 AccountConfigFailure")] + [DataRow(DistributedTransactionConstants.DtcDispatchFailure, DisplayName = "500/5413 DispatchFailure")] + public async Task DtxRequest_500_InfraFailure_ShouldRetry(int subStatusCode) + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + DocumentServiceRequest request = ClientRetryPolicyTests.CreateDtxRequest(); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage(HttpStatusCode.InternalServerError); + response.Headers.SubStatusCodeLiteral = subStatusCode.ToString(); + + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsTrue(result.ShouldRetry, $"DTX 500/{subStatusCode} transient infra failure must be retried."); + } + + [TestMethod] + public async Task NonDtxWriteRequest_408_ShouldNotRetry() + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + // Non-DTX write (e.g., a point Create) + DocumentServiceRequest request = this.CreateRequest(isReadRequest: false, isMasterResourceType: false); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage(HttpStatusCode.RequestTimeout); + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsFalse(result.ShouldRetry, "Non-DTX 408 must NOT be retried by ClientRetryPolicy (only marks endpoint unavailable)."); + } + + [DataTestMethod] + [DataRow(DistributedTransactionConstants.DtcLedgerFailure, DisplayName = "500/5411 LedgerFailure")] + [DataRow(DistributedTransactionConstants.DtcAccountConfigFailure, DisplayName = "500/5412 AccountConfigFailure")] + [DataRow(DistributedTransactionConstants.DtcDispatchFailure, DisplayName = "500/5413 DispatchFailure")] + public async Task NonDtxWriteRequest_500_DtcSubStatus_ShouldNotRetry(int subStatusCode) + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + // Non-DTX write — same sub-status codes must NOT trigger a retry. + DocumentServiceRequest request = this.CreateRequest(isReadRequest: false, isMasterResourceType: false); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage(HttpStatusCode.InternalServerError); + response.Headers.SubStatusCodeLiteral = subStatusCode.ToString(); + + ShouldRetryResult result = await policy.ShouldRetryAsync(response, CancellationToken.None); + + Assert.IsFalse(result.ShouldRetry, $"Non-DTX write 500/{subStatusCode} must NOT be retried — only DTX writes with idempotency tokens are safe."); + } + + [TestMethod] + public async Task DtxRequest_ExhaustsRetryBudget_ReturnsNoRetry() + { + const bool enableEndpointDiscovery = true; + using GlobalEndpointManager endpointManager = this.Initialize( + useMultipleWriteLocations: false, + enableEndpointDiscovery: enableEndpointDiscovery, + isPreferredLocationsListEmpty: false, + enforceSingleMasterSingleWriteLocation: true); + + ClientRetryPolicy policy = new ClientRetryPolicy(endpointManager, this.partitionKeyRangeLocationCache, new RetryOptions(), enableEndpointDiscovery, false); + DocumentServiceRequest request = ClientRetryPolicyTests.CreateDtxRequest(); + policy.OnBeforeSendRequest(request); + + ResponseMessage response = new ResponseMessage(HttpStatusCode.RequestTimeout); + + // Exhaust the full DTX retry budget (MaxDtxRetryCount = 100). + for (int i = 0; i < 100; i++) + { + ShouldRetryResult retryResult = await policy.ShouldRetryAsync(response, CancellationToken.None); + Assert.IsTrue(retryResult.ShouldRetry, $"DTX 408 retry {i + 1} of 100 should be allowed."); + } + + // The 101st call must be denied. + ShouldRetryResult finalResult = await policy.ShouldRetryAsync(response, CancellationToken.None); + Assert.IsFalse(finalResult.ShouldRetry, "DTX retry budget is exhausted after 100 retries; the 101st must be denied."); + } + + private static DocumentServiceRequest CreateDtxRequest() + { + return DocumentServiceRequest.Create( + OperationType.CommitDistributedTransaction, + ResourceType.DistributedTransactionBatch, + AuthorizationTokenType.PrimaryMasterKey); + } + private static GlobalPartitionEndpointManagerCore.PartitionKeyRangeFailoverInfo GetPartitionKeyRangeFailoverInfoUsingReflection( GlobalPartitionEndpointManager globalPartitionEndpointManager, PartitionKeyRange pkRange, diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs index d6bf64dcbb..2c6a75826d 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionCommitterTests.cs @@ -30,6 +30,9 @@ public class DistributedTransactionCommitterTests private static readonly string CollectionResourceId = ResourceId.NewDocumentCollectionId(42, 129).DocumentCollectionId.ToString(); + // Known-valid collection resource ID that passes ResourceId.Parse. + private const string TestCollectionResourceId = "ccZ1ANCszwk="; + [TestMethod] [Description("Verifies that when the DTC response carries a session token, the token is merged into the SessionContainer")] public async Task CommitTransactionAsync_MergesSessionTokensIntoSessionContainer() @@ -273,6 +276,540 @@ public async Task CommitTransactionAsync_MergesSessionTokens_OnFailureResponse() "Session token should still be merged even when the DTC response indicates a failure."); } + // ─── Retry / Spec-Compliance Tests ───────────────────────────────────── + + [TestMethod] + [Description("Verifies that a commit succeeds without retrying when the server returns a success response on the first attempt.")] + public async Task CommitTransaction_SucceedsOnFirstAttempt() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult(CreateSuccessResponseMessage(operationCount: 1)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.IsTrue(response.IsSuccessStatusCode); + Assert.IsFalse(response.IsRetriable); + Assert.AreEqual(1, callCount); + } + } + + [TestMethod] + [Description("Verifies that when the server responds with isRetriable:true, the committer retries and eventually succeeds.")] + public async Task CommitTransaction_RetriesOnRetriableResponse_ThenSucceeds() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + if (callCount == 1) + { + return Task.FromResult(CreateRetriableErrorResponseMessage()); + } + + return Task.FromResult(CreateSuccessResponseMessage(operationCount: 1)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.IsTrue(response.IsSuccessStatusCode); + Assert.AreEqual(2, callCount); + } + } + + [TestMethod] + [Description("Verifies that the committer retries on isRetriable:true responses until the cancellation token is cancelled (before the retry budget is exhausted).")] + public async Task CommitTransaction_RetriableResponse_RetriesUntilCancelledBeforeBudgetExhausted() + { + using (CancellationTokenSource cts = new CancellationTokenSource()) + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + if (callCount == 3) + { + cts.Cancel(); + } + + return Task.FromResult(CreateRetriableErrorResponseMessage()); + }); + + // Non-zero delay so Task.Delay honours the already-cancelled token. + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + CreateTestOperations(), + mockContext.Object, + TimeSpan.FromMilliseconds(1)); + + await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(cts.Token)); + + // Retries continue until the cancellation token fires (before exhausting the budget). + Assert.AreEqual(3, callCount); + } + } + + [TestMethod] + [Description("Verifies that the outer isRetriable retry loop returns the last response after exhausting the retry budget (MaxIsRetriableRetryCount).")] + public async Task CommitTransaction_ExhaustsIsRetriableRetryBudget_ReturnsLastResponse() + { + int callCount = 0; + List capturedDelays = new List(); + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult(CreateRetriableErrorResponseMessage()); + }); + + Func captureDelay = (delay, _) => + { + capturedDelays.Add(delay); + return Task.CompletedTask; + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + CreateTestOperations(), + mockContext.Object, + retryBaseDelay: TimeSpan.Zero, + delayProvider: captureDelay); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + // MaxIsRetriableRetryCount (10) retries + 1 final call that hits the budget check = 11 total calls. + Assert.AreEqual(DistributedTransactionCommitter.MaxIsRetriableRetryCount + 1, callCount, + "Expected exactly MaxIsRetriableRetryCount retries plus one final call that triggers budget exhaustion."); + Assert.AreEqual(DistributedTransactionCommitter.MaxIsRetriableRetryCount, capturedDelays.Count, + "Delay provider must be called once per retry attempt."); + Assert.IsFalse(response.IsSuccessStatusCode, + "The returned response must be the last non-success response."); + Assert.IsTrue(response.IsRetriable, + "The returned response must still have IsRetriable=true (budget exhausted, not a new response)."); + } + } + + [TestMethod] + [Description("Verifies that a CosmosException thrown from the pipeline propagates immediately without triggering the outer retry loop, which only handles the isRetriable JSON body flag.")] + public async Task CommitTransaction_CosmosExceptionDuringRequest_PropagatesImmediately() + { + using (CancellationTokenSource cts = new CancellationTokenSource()) + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + cts.Cancel(); // Cancel while the request is in-flight. + return Task.FromException(CreateCosmosTimeoutException()); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + CreateTestOperations(), + mockContext.Object, + TimeSpan.FromMilliseconds(1)); + + // Status-code-based retries (including 408) are handled by ClientRetryPolicy inside + // the pipeline. CosmosExceptions thrown directly from ProcessResourceOperationStreamAsync + // propagate through the outer loop, which only handles the isRetriable JSON body flag. + CosmosException ex = await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(cts.Token)); + + Assert.AreEqual(HttpStatusCode.RequestTimeout, ex.StatusCode); + Assert.AreEqual(1, callCount); + } + } + + [TestMethod] + [Description("Verifies that a response without isRetriable:true is returned immediately without any retry attempt.")] + public async Task CommitTransaction_DoesNotRetryOnNonRetriableFailure() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult(CreateNonRetriableErrorResponseMessage()); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.BadRequest, response.StatusCode); + Assert.IsFalse(response.IsSuccessStatusCode); + Assert.IsFalse(response.IsRetriable); + Assert.AreEqual(1, callCount); + } + } + + [TestMethod] + [Description("Verifies that a generic 500 response body without isRetriable:true does not trigger an outer retry.")] + public async Task CommitTransaction_DoesNotRetryOnNonRetriableServerError() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult( + new ResponseMessage(HttpStatusCode.InternalServerError) + { + Content = new MemoryStream(Encoding.UTF8.GetBytes("{}")) + }); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.IsFalse(response.IsSuccessStatusCode); + Assert.IsFalse(response.IsRetriable); + Assert.AreEqual(1, callCount); + } + } + + [TestMethod] + [Description("Verifies that a pre-cancelled CancellationToken causes CommitTransactionAsync to throw immediately without issuing any network request.")] + public async Task CommitTransaction_RespectsCancellationToken_PreCancelled() + { + using (CancellationTokenSource cts = new CancellationTokenSource()) + { + cts.Cancel(); + + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => throw new InvalidOperationException("Should not be called on a pre-cancelled token.")); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(cts.Token)); + + this.VerifyProcessResourceOperationCallCount(mockContext, Times.Never()); + } + } + + [TestMethod] + [Description("Verifies that cancelling the token during the retry delay causes OperationCanceledException to propagate rather than proceeding with the next attempt.")] + public async Task CommitTransaction_CancelledDuringRetryDelay_ThrowsOperationCanceledException() + { + using (CancellationTokenSource cts = new CancellationTokenSource()) + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + cts.Cancel(); // Cancel after the first call so the retry delay throws. + return Task.FromResult(CreateRetriableErrorResponseMessage()); + }); + + // Non-zero delay so the retry path enters Task.Delay + // the token is already cancelled synchronously in the callback, so it throws immediately. + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + CreateTestOperations(), + mockContext.Object, + TimeSpan.FromMilliseconds(500)); + + await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(cts.Token)); + + Assert.AreEqual(1, callCount); + } + } + + [TestMethod] + [Description("Verifies that the committer retries on multiple consecutive isRetriable responses and eventually returns the success response.")] + public async Task CommitTransaction_MultipleRetriesThenSuccessOnLastAttempt() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + if (callCount <= 3) + { + return Task.FromResult(CreateRetriableErrorResponseMessage()); + } + + return Task.FromResult(CreateSuccessResponseMessage(operationCount: 1)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + // 3 retriable failures + 1 success = 4 total calls. + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.IsTrue(response.IsSuccessStatusCode); + Assert.AreEqual(4, callCount); + } + } + + [TestMethod] + [Description("Verifies that a non-CosmosException thrown from the pipeline propagates immediately without retrying.")] + public async Task CommitTransaction_NonCosmosException_PropagatesImmediately() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromException(new IOException("Network error")); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + IOException ex = await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(CancellationToken.None)); + + Assert.AreEqual("Network error", ex.Message); + Assert.AreEqual(1, callCount); + } + + [TestMethod] + [Description("Verifies that any CosmosException thrown from the pipeline propagates immediately without triggering the outer retry loop.")] + public async Task CommitTransaction_AnyCosmosException_PropagatesImmediately() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + CosmosException notFound = new CosmosException( + "Not found", + HttpStatusCode.NotFound, + subStatusCode: 0, + activityId: null, + requestCharge: 0); + return Task.FromException(notFound); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + CosmosException ex = await Assert.ThrowsExceptionAsync( + () => committer.CommitTransactionAsync(CancellationToken.None)); + + Assert.AreEqual(HttpStatusCode.NotFound, ex.StatusCode); + Assert.AreEqual(1, callCount); + } + + [TestMethod] + [Description("Verifies that the same idempotency token is used on all retry attempts, which is required for safe replay of committed transactions.")] + public async Task CommitTransaction_SameIdempotencyTokenSentOnEveryRetryAttempt() + { + int callCount = 0; + List capturedTokens = new List(); + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperationWithEnricherCapture( + mockContext, + enricher => + { + RequestMessage request = new RequestMessage + { + ResourceType = ResourceType.DistributedTransactionBatch, + OperationType = OperationType.CommitDistributedTransaction, + }; + enricher(request); + capturedTokens.Add(request.Headers[HttpConstants.HttpHeaders.IdempotencyToken]); + }, + () => + { + callCount++; + return callCount < 3 + ? Task.FromResult(CreateRetriableErrorResponseMessage()) + : Task.FromResult(CreateSuccessResponseMessage(operationCount: 1)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(3, callCount); + } + + Assert.AreEqual(3, capturedTokens.Count); + CollectionAssert.AllItemsAreNotNull(capturedTokens, "IdempotencyToken header must be set on every request."); + Assert.AreEqual(1, new HashSet(capturedTokens).Count, "The same idempotency token must be used on every retry attempt."); + } + + [TestMethod] + [Description("Verifies that a 449 response with a sub-status code other than DtcCoordinatorRaceConflict (5352) is not retried by the outer loop.")] + public async Task CommitTransaction_DoesNotRetryOn449WithNonRaceConflictSubStatus() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + // 449 with sub-status 0 — not a DTx race conflict, must not be retried. + return Task.FromResult(CreateEmptyResponseMessage((HttpStatusCode)StatusCodes.RetryWith, subStatusCode: 0)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual((HttpStatusCode)StatusCodes.RetryWith, response.StatusCode); + Assert.AreEqual(1, callCount, "449 with non-5352 sub-status must not be retried."); + } + } + + [TestMethod] + [Description("Verifies that a generic 500 InternalServerError with a non-DTC sub-status code is not retried by the outer loop.")] + public async Task CommitTransaction_DoesNotRetryOn500WithNonDtcSubStatus() + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult(CreateEmptyResponseMessage(HttpStatusCode.InternalServerError, subStatusCode: 0)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.AreEqual(1, callCount, "Generic 500 with non-DTC sub-status must not be retried."); + } + } + + [DataTestMethod] + [Description("Verifies that DTC validation failure responses (400 with DTC-specific sub-status codes) are never retried by the outer loop.")] + [DataRow(5405, DisplayName = "400/5405 ParseFailure")] + [DataRow(5406, DisplayName = "400/5406 FeatureDisabled")] + [DataRow(5407, DisplayName = "400/5407 MaxOpsExceeded")] + [DataRow(5408, DisplayName = "400/5408 MissingIdempotencyToken")] + [DataRow(5409, DisplayName = "400/5409 InvalidAccountName")] + [DataRow(5410, DisplayName = "400/5410 InvalidOperation")] + public async Task CommitTransaction_DoesNotRetryOnValidationFailure400(int subStatusCode) + { + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return Task.FromResult(CreateEmptyResponseMessage(HttpStatusCode.BadRequest, subStatusCode)); + }); + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter(CreateTestOperations(), mockContext.Object, TimeSpan.Zero); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.BadRequest, response.StatusCode); + Assert.AreEqual(1, callCount, $"Validation failure 400/{subStatusCode} must not be retried."); + } + } + + [TestMethod] + [Description("Verifies that GetRetryDelay produces exponentially growing delays with a cap at maxExponent=5, and that each delay falls within the expected jitter range [0.5*base*2^n, 1.5*base*2^n].")] + public async Task GetRetryDelay_ExponentialBackoff_DelaysGrowAndCapCorrectly() + { + const int retryCount = 7; + TimeSpan baseDelay = TimeSpan.FromSeconds(1); + List capturedDelays = new List(); + + // Set up: retryCount retriable responses so we capture retryCount delay values. + int callCount = 0; + Mock mockContext = this.CreateMockClientContext(); + this.SetupProcessResourceOperation( + mockContext, + () => + { + callCount++; + return callCount <= retryCount + ? Task.FromResult(CreateRetriableErrorResponseMessage()) + : Task.FromResult(CreateSuccessResponseMessage(operationCount: 1)); + }); + + Func captureDelay = (delay, _) => + { + capturedDelays.Add(delay); + return Task.CompletedTask; + }; + + DistributedTransactionCommitter committer = new DistributedTransactionCommitter( + CreateTestOperations(), + mockContext.Object, + retryBaseDelay: baseDelay, + delayProvider: captureDelay); + + using (DistributedTransactionResponse response = await committer.CommitTransactionAsync(CancellationToken.None)) + { + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + } + + Assert.AreEqual(retryCount, capturedDelays.Count, "One delay per retry attempt."); + + const int maxExponent = 5; + for (int i = 0; i < capturedDelays.Count; i++) + { + int exponent = Math.Min(i, maxExponent); + double baseMs = baseDelay.TotalMilliseconds * Math.Pow(2, exponent); + double minMs = baseMs * 0.5; + double maxMs = baseMs * 1.5; + + Assert.IsTrue( + capturedDelays[i].TotalMilliseconds >= minMs && capturedDelays[i].TotalMilliseconds <= maxMs, + $"Attempt {i}: delay {capturedDelays[i].TotalMilliseconds:F0}ms must be in [{minMs:F0}, {maxMs:F0}]ms."); + } + + // Delays at attempt >= maxExponent should be at the same magnitude (capped exponent). + double cappedBase = baseDelay.TotalMilliseconds * Math.Pow(2, maxExponent); + Assert.IsTrue( + capturedDelays[maxExponent].TotalMilliseconds >= cappedBase * 0.5 + && capturedDelays[maxExponent].TotalMilliseconds <= cappedBase * 1.5, + "Delay at maxExponent must be capped."); + Assert.IsTrue( + capturedDelays[maxExponent + 1].TotalMilliseconds >= cappedBase * 0.5 + && capturedDelays[maxExponent + 1].TotalMilliseconds <= cappedBase * 1.5, + "Delay beyond maxExponent must still use the capped exponent, producing a similar magnitude."); + } + // ─── Helpers ─────────────────────────────────────────────────────────── private static string BuildDtcResponseJson( @@ -357,5 +894,156 @@ private Mock CreateMockContext( return mockContext; } + + // ─── Retry test helpers ──────────────────────────────────────────────── + + private Mock CreateMockClientContext() + { + Mock mockContext = new Mock(); + + mockContext.Setup(x => x.SerializerCore).Returns(MockCosmosUtil.Serializer); + + mockContext.Setup(x => x.GetCachedContainerPropertiesAsync( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(ContainerProperties.CreateWithResourceId(TestCollectionResourceId)); + + return mockContext; + } + + private void SetupProcessResourceOperation( + Mock mockContext, + Func> responseFactory) + { + mockContext + .Setup(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(responseFactory); + } + + private void SetupProcessResourceOperationWithEnricherCapture( + Mock mockContext, + Action> enricherCallback, + Func> responseFactory) + { + mockContext + .Setup(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ITrace, CancellationToken>( + (_, _, _, _, _, _, _, _, enricher, _, _) => enricherCallback(enricher)) + .Returns(responseFactory); + } + + private void VerifyProcessResourceOperationCallCount( + Mock mockContext, + Times times) + { + mockContext.Verify(c => c.ProcessResourceOperationStreamAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny()), times); + } + + private static IReadOnlyList CreateTestOperations(int count = 1) + { + List operations = new List(count); + for (int i = 0; i < count; i++) + { + operations.Add(new DistributedTransactionOperation( + OperationType.Create, + i, + "testDb", + "testContainer", + Cosmos.PartitionKey.Null)); + } + + return operations; + } + + private static ResponseMessage CreateSuccessResponseMessage(int operationCount) + { + StringBuilder json = new StringBuilder(); + json.Append("{\"operationResponses\":["); + for (int i = 0; i < operationCount; i++) + { + if (i > 0) + { + json.Append(","); + } + + json.Append($"{{\"index\":{i},\"statuscode\":200,\"substatuscode\":0}}"); + } + + json.Append("]}"); + + return new ResponseMessage(HttpStatusCode.OK) + { + Content = new MemoryStream(Encoding.UTF8.GetBytes(json.ToString())) + }; + } + + private static ResponseMessage CreateRetriableErrorResponseMessage() + { + string json = "{\"isRetriable\":true}"; + return new ResponseMessage(HttpStatusCode.ServiceUnavailable) + { + Content = new MemoryStream(Encoding.UTF8.GetBytes(json)) + }; + } + + private static ResponseMessage CreateNonRetriableErrorResponseMessage() + { + return new ResponseMessage(HttpStatusCode.BadRequest) + { + Content = new MemoryStream(Encoding.UTF8.GetBytes("{}")) + }; + } + + private static CosmosException CreateCosmosTimeoutException() + { + return new CosmosException( + "Request timed out", + HttpStatusCode.RequestTimeout, + subStatusCode: 0, + activityId: null, + requestCharge: 0); + } + + /// Creates an empty-body response with the given status and sub-status codes. + private static ResponseMessage CreateEmptyResponseMessage(HttpStatusCode statusCode, int subStatusCode) + { + ResponseMessage message = new ResponseMessage(statusCode); + message.Headers.SubStatusCodeLiteral = subStatusCode.ToString(); + return message; + } } } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionResponseTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionResponseTests.cs index 13453e21de..baedef2ea9 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionResponseTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionResponseTests.cs @@ -623,6 +623,133 @@ public async Task FromResponseMessage_OperationResult_SessionToken_DeserializesC "SessionToken must equal the value from the JSON 'sessionToken' field."); } + // IsRetriable parsing + + [TestMethod] + [Description("When the response body contains isRetriable:true, IsRetriable must be true.")] + public async Task FromResponseMessage_IsRetriableTrue_ReturnsTrue() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + + string json = @"{""isRetriable"":true,""operationResponses"":[{""index"":0,""statusCode"":503}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.ServiceUnavailable, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + Assert.IsTrue(response.IsRetriable, "IsRetriable must be true when the JSON body contains isRetriable:true."); + } + + [TestMethod] + [Description("When the response body contains isRetriable:false, IsRetriable must be false.")] + public async Task FromResponseMessage_IsRetriableFalse_ReturnsFalse() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + + string json = @"{""isRetriable"":false,""operationResponses"":[{""index"":0,""statusCode"":503}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.ServiceUnavailable, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + Assert.IsFalse(response.IsRetriable, "IsRetriable must be false when the JSON body contains isRetriable:false."); + } + + [TestMethod] + [Description("When the response body does not contain an isRetriable field, IsRetriable must default to false.")] + public async Task FromResponseMessage_IsRetriableAbsent_ReturnsFalse() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + + string json = @"{""operationResponses"":[{""index"":0,""statusCode"":503}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.ServiceUnavailable, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + Assert.IsFalse(response.IsRetriable, "IsRetriable must be false when the JSON body does not contain an isRetriable field."); + } + + [TestMethod] + [Description("When isRetriable is a string (not a boolean), IsRetriable must be false — strict boolean parsing.")] + public async Task FromResponseMessage_IsRetriableStringValue_ReturnsFalse() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + + // "true" as a string is not a JSON boolean — the parsing must require JsonValueKind.True. + string json = @"{""isRetriable"":""true"",""operationResponses"":[{""index"":0,""statusCode"":503}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.ServiceUnavailable, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + Assert.IsFalse(response.IsRetriable, "IsRetriable must be false when isRetriable is a string — only JSON boolean true is accepted."); + } + + // ThrowIfDisposed guards on Count and GetEnumerator + + [TestMethod] + [Description("Calling Count after Dispose() must throw ObjectDisposedException.")] + public async Task Count_AfterDispose_ThrowsObjectDisposedException() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + string json = @"{""operationResponses"":[{""index"":0,""statusCode"":201}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.OK, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + response.Dispose(); + + Assert.ThrowsException(() => _ = response.Count); + } + + [TestMethod] + [Description("Calling GetEnumerator after Dispose() must throw ObjectDisposedException.")] + public async Task GetEnumerator_AfterDispose_ThrowsObjectDisposedException() + { + DistributedTransactionServerRequest serverRequest = await BuildServerRequestAsync(operationCount: 1); + string json = @"{""operationResponses"":[{""index"":0,""statusCode"":201}]}"; + ResponseMessage responseMessage = BuildResponseMessage(HttpStatusCode.OK, json); + + DistributedTransactionResponse response = await DistributedTransactionResponse.FromResponseMessageAsync( + responseMessage, + serverRequest, + MockCosmosUtil.Serializer, + Guid.NewGuid(), + NoOpTrace.Singleton, + CancellationToken.None); + + response.Dispose(); + + Assert.ThrowsException(() => response.GetEnumerator()); + } + // Helpers /// diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionServerRequestTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionServerRequestTests.cs new file mode 100644 index 0000000000..0fa16078d8 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/DistributedTransaction/DistributedTransactionServerRequestTests.cs @@ -0,0 +1,69 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Tests.DistributedTransaction +{ + using System.Collections.Generic; + using System.IO; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Tests; + using Microsoft.Azure.Documents; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using PartitionKey = Microsoft.Azure.Cosmos.PartitionKey; + + [TestClass] + public class DistributedTransactionServerRequestTests + { + [TestMethod] + [Description("Verifies that CreateBodyStream returns a new independent MemoryStream on each call, enabling safe retry — disposing one stream must not affect siblings, and all streams must contain identical serialized bytes.")] + public async Task CreateBodyStream_CalledMultipleTimes_ReturnsIndependentStreams() + { + DistributedTransactionServerRequest request = await DistributedTransactionServerRequest.CreateAsync( + CreateTestOperations(), + MockCosmosUtil.Serializer, + CancellationToken.None); + + using (MemoryStream stream1 = request.CreateBodyStream()) + using (MemoryStream stream2 = request.CreateBodyStream()) + { + Assert.AreNotSame(stream1, stream2, "Each call must return a new stream instance."); + Assert.AreEqual(0, stream1.Position, "stream1 must be positioned at offset 0."); + Assert.AreEqual(0, stream2.Position, "stream2 must be positioned at offset 0."); + Assert.IsTrue(stream1.Length > 0, "The serialized body must be non-empty."); + Assert.AreEqual(stream1.Length, stream2.Length, "Both streams must contain the same number of bytes."); + CollectionAssert.AreEqual( + stream1.ToArray(), + stream2.ToArray(), + "Both streams must contain identical serialized bytes."); + } + + // Obtain a third stream after the first two have been disposed. + using (MemoryStream stream3 = request.CreateBodyStream()) + { + Assert.IsTrue(stream3.CanRead, "A stream obtained after disposing siblings must still be readable."); + Assert.AreEqual(0, stream3.Position, "stream3 must be positioned at offset 0."); + } + } + + private static IReadOnlyList CreateTestOperations() + { + return new List + { + new DistributedTransactionOperation( + OperationType.Create, + operationIndex: 0, + database: "testDb", + container: "testContainer", + partitionKey: new PartitionKey("pk0")), + new DistributedTransactionOperation( + OperationType.Upsert, + operationIndex: 1, + database: "testDb", + container: "testContainer", + partitionKey: new PartitionKey("pk1")), + }; + } + } +}