Skip to content

Commit 7afaa25

Browse files
Merge branch 'master' into copilot/fix-unconditional-dll-copying
2 parents ae02b3e + e824e60 commit 7afaa25

5 files changed

Lines changed: 183 additions & 2 deletions

File tree

Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,24 @@ public int GatewayModeMaxConnectionLimit
306306
/// <seealso cref="CosmosClientBuilder.WithRequestTimeout(TimeSpan)"/>
307307
public TimeSpan RequestTimeout { get; set; }
308308

309+
/// <summary>
310+
/// Gets or sets the request timeout for inference service operations (e.g., semantic reranking).
311+
/// The number specifies the time to wait for a response from the inference service before the request is cancelled.
312+
/// This is a single-attempt timeout with no retries.
313+
/// </summary>
314+
/// <value>Default value is 5 seconds.</value>
315+
/// <remarks>
316+
/// This timeout is specific to inference service operations and is separate from the standard <see cref="RequestTimeout"/>.
317+
/// If the request does not complete within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
318+
/// No retries are attempted on timeout.
319+
/// </remarks>
320+
#if PREVIEW
321+
public
322+
#else
323+
internal
324+
#endif
325+
TimeSpan InferenceRequestTimeout { get; set; } = InferenceService.DefaultInferenceRequestTimeout;
326+
309327
/// <summary>
310328
/// The SDK does a background refresh based on the time interval set to refresh the token credentials.
311329
/// This avoids latency issues because the old token is used until the new token is retrieved.

Microsoft.Azure.Cosmos/src/Fluent/CosmosClientBuilder.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,26 @@ public CosmosClientBuilder WithRequestTimeout(TimeSpan requestTimeout)
387387
return this;
388388
}
389389

390+
/// <summary>
391+
/// Sets the request timeout for inference service operations (e.g., semantic reranking).
392+
/// This is a single-attempt timeout with no retries; if the request does not complete
393+
/// within the specified duration, a <see cref="CosmosException"/> with status 408 (Request Timeout) is thrown.
394+
/// </summary>
395+
/// <param name="inferenceRequestTimeout">A time to use as timeout for inference operations.</param>
396+
/// <value>Default value is 5 seconds.</value>
397+
/// <returns>The current <see cref="CosmosClientBuilder"/>.</returns>
398+
/// <seealso cref="CosmosClientOptions.InferenceRequestTimeout"/>
399+
#if PREVIEW
400+
public
401+
#else
402+
internal
403+
#endif
404+
CosmosClientBuilder WithInferenceRequestTimeout(TimeSpan inferenceRequestTimeout)
405+
{
406+
this.clientOptions.InferenceRequestTimeout = inferenceRequestTimeout;
407+
return this;
408+
}
409+
390410
/// <summary>
391411
/// Sets the connection mode to Direct. This is used by the client when connecting to the Azure Cosmos DB service.
392412
/// </summary>

Microsoft.Azure.Cosmos/src/Inference/InferenceService.cs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace Microsoft.Azure.Cosmos
1515
using System.Threading;
1616
using System.Threading.Tasks;
1717
using global::Azure.Core;
18+
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
1819
using Microsoft.Azure.Documents;
1920
using Microsoft.Azure.Documents.Collections;
2021

@@ -32,9 +33,16 @@ internal class InferenceService : IDisposable
3233
private const string InferenceTokenPrefix = "Bearer ";
3334
private const int inferenceServiceDefaultMaxConnectionLimit = 50;
3435

36+
/// <summary>
37+
/// Default per-request timeout for inference requests. Referenced by
38+
/// <see cref="CosmosClientOptions.InferenceRequestTimeout"/>.
39+
/// </summary>
40+
internal static readonly TimeSpan DefaultInferenceRequestTimeout = TimeSpan.FromSeconds(5);
41+
3542
private readonly int inferenceServiceMaxConnectionLimit;
3643
private readonly string inferenceServiceBaseUrl;
3744
private readonly Uri inferenceEndpoint;
45+
private readonly TimeSpan inferenceRequestTimeout;
3846

3947
private HttpClient httpClient;
4048
private AuthorizationTokenProvider cosmosAuthorization;
@@ -59,6 +67,9 @@ public InferenceService(CosmosClient client)
5967
"AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_SERVICE_MAX_CONNECTION_LIMIT",
6068
inferenceServiceDefaultMaxConnectionLimit) ?? inferenceServiceDefaultMaxConnectionLimit;
6169

70+
Debug.Assert(client.ClientOptions != null, "ClientOptions should not be null");
71+
this.inferenceRequestTimeout = client.ClientOptions.InferenceRequestTimeout;
72+
6273
// Create and configure HttpClient for inference requests.
6374
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
6475
gatewayModeMaxConnectionLimit: this.inferenceServiceMaxConnectionLimit,
@@ -95,6 +106,7 @@ public InferenceService(CosmosClient client)
95106
/// </summary>
96107
internal InferenceService(HttpMessageHandler messageHandler, Uri inferenceEndpoint, AuthorizationTokenProvider cosmosAuthorization)
97108
{
109+
this.inferenceRequestTimeout = InferenceService.DefaultInferenceRequestTimeout;
98110
this.httpClient = new HttpClient(messageHandler);
99111
this.CreateClientHelper(this.httpClient);
100112
this.inferenceEndpoint = inferenceEndpoint;
@@ -115,6 +127,8 @@ public async Task<SemanticRerankResult> SemanticRerankAsync(
115127
IDictionary<string, object> options = null,
116128
CancellationToken cancellationToken = default)
117129
{
130+
DateTime startDateTimeUtc = DateTime.UtcNow;
131+
118132
// Prepare HTTP request for semantic reranking.
119133
HttpRequestMessage message = new HttpRequestMessage(HttpMethod.Post, this.inferenceEndpoint);
120134
INameValueCollection additionalHeaders = new RequestNameValueCollection();
@@ -139,8 +153,29 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
139153
Encoding.UTF8,
140154
RuntimeConstants.MediaTypes.Json);
141155

142-
// Send the request and check for success.
143-
HttpResponseMessage responseMessage = await this.httpClient.SendAsync(message, cancellationToken);
156+
// Enforce a single-attempt, no-retry timeout for the inference request.
157+
// HttpClient.Timeout is intentionally left unchanged; this linked CTS is the authoritative
158+
// per-request timeout for inference calls.
159+
using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
160+
linkedCts.CancelAfter(this.inferenceRequestTimeout);
161+
162+
HttpResponseMessage responseMessage;
163+
try
164+
{
165+
responseMessage = await this.httpClient.SendAsync(message, linkedCts.Token);
166+
}
167+
catch (OperationCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested)
168+
{
169+
// Timeout triggered by the linked CTS (not the caller's cancellationToken).
170+
string errorMessage = $"Inference Service Request Timeout. Start Time UTC:{startDateTimeUtc}; Total Duration:{(DateTime.UtcNow - startDateTimeUtc).TotalMilliseconds} Ms; Inference Request Timeout:{this.inferenceRequestTimeout.TotalMilliseconds} Ms; Activity id: {System.Diagnostics.Trace.CorrelationManager.ActivityId};";
171+
throw CosmosExceptionFactory.CreateRequestTimeoutException(
172+
message: errorMessage,
173+
headers: new Headers()
174+
{
175+
ActivityId = System.Diagnostics.Trace.CorrelationManager.ActivityId.ToString()
176+
},
177+
innerException: operationCanceledException);
178+
}
144179

145180
if (!responseMessage.IsSuccessStatusCode)
146181
{

Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetPreviewSDKAPI.net6.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,11 +392,30 @@
392392
"Attributes": [],
393393
"MethodInfo": "System.Nullable`1[System.Int32] ThroughputBucket;CanRead:True;CanWrite:True;System.Nullable`1[System.Int32] get_ThroughputBucket();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_ThroughputBucket(System.Nullable`1[System.Int32]);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
394394
},
395+
"System.TimeSpan get_InferenceRequestTimeout()[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
396+
"Type": "Method",
397+
"Attributes": [
398+
"CompilerGeneratedAttribute"
399+
],
400+
"MethodInfo": "System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
401+
},
402+
"System.TimeSpan InferenceRequestTimeout": {
403+
"Type": "Property",
404+
"Attributes": [],
405+
"MethodInfo": "System.TimeSpan InferenceRequestTimeout;CanRead:True;CanWrite:True;System.TimeSpan get_InferenceRequestTimeout();IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
406+
},
395407
"Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean)": {
396408
"Type": "Method",
397409
"Attributes": [],
398410
"MethodInfo": "Void set_EnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
399411
},
412+
"Void set_InferenceRequestTimeout(System.TimeSpan)[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
413+
"Type": "Method",
414+
"Attributes": [
415+
"CompilerGeneratedAttribute"
416+
],
417+
"MethodInfo": "Void set_InferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
418+
},
400419
"Void set_ReadConsistencyStrategy(System.Nullable`1[Microsoft.Azure.Cosmos.ReadConsistencyStrategy])[System.Runtime.CompilerServices.CompilerGeneratedAttribute()]": {
401420
"Type": "Method",
402421
"Attributes": [
@@ -1232,6 +1251,11 @@
12321251
"Attributes": [],
12331252
"MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithEnableRemoteRegionPreferredForSessionRetry(Boolean);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
12341253
},
1254+
"Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan)": {
1255+
"Type": "Method",
1256+
"Attributes": [],
1257+
"MethodInfo": "Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithInferenceRequestTimeout(System.TimeSpan);IsAbstract:False;IsStatic:False;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
1258+
},
12351259
"Microsoft.Azure.Cosmos.Fluent.CosmosClientBuilder WithReadConsistencyStrategy(Microsoft.Azure.Cosmos.ReadConsistencyStrategy)": {
12361260
"Type": "Method",
12371261
"Attributes": [],

Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/InferenceServiceTests.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,60 @@ public async Task SemanticRerankAsync_SuccessResponse_ReturnsResult()
9494
Assert.AreEqual(0, result.RerankScores[0].Index);
9595
}
9696

97+
[TestMethod]
98+
public async Task SemanticRerankAsync_RequestExceedsInferenceTimeout_Throws408CosmosException()
99+
{
100+
// Handler delays for 10 seconds; the internal InferenceService ctor uses the
101+
// DefaultInferenceRequestTimeout (5 seconds), so the linked CTS should cancel first.
102+
DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
103+
delay: TimeSpan.FromSeconds(10),
104+
statusCode: HttpStatusCode.OK,
105+
responseContent: "{}");
106+
107+
Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
108+
109+
using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
110+
111+
CosmosException exception = await Assert.ThrowsExceptionAsync<CosmosException>(
112+
() => service.SemanticRerankAsync(
113+
rerankContext: "test query",
114+
documents: new List<string> { "doc1", "doc2" }));
115+
116+
Assert.AreEqual(HttpStatusCode.RequestTimeout, exception.StatusCode);
117+
Assert.IsTrue(
118+
exception.Message.Contains("Inference Service Request Timeout"),
119+
$"Expected timeout message. Actual: {exception.Message}");
120+
}
121+
122+
[TestMethod]
123+
public async Task SemanticRerankAsync_UserCancellation_PropagatesOperationCanceledException()
124+
{
125+
// Handler delays long enough that user cancellation should fire first.
126+
DelayedMessageHandler delayedHandler = new DelayedMessageHandler(
127+
delay: TimeSpan.FromSeconds(10),
128+
statusCode: HttpStatusCode.OK,
129+
responseContent: "{}");
130+
131+
Mock<AuthorizationTokenProvider> mockAuth = InferenceServiceTests.CreateMockAuthorizationTokenProvider();
132+
133+
using InferenceService service = new InferenceService(delayedHandler, TestEndpoint, mockAuth.Object);
134+
using CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(200));
135+
136+
try
137+
{
138+
await service.SemanticRerankAsync(
139+
rerankContext: "test query",
140+
documents: new List<string> { "doc1", "doc2" },
141+
cancellationToken: cts.Token);
142+
Assert.Fail("Expected OperationCanceledException to propagate when the caller cancels.");
143+
}
144+
catch (OperationCanceledException)
145+
{
146+
// Expected: user cancellation should surface as OperationCanceledException (or its
147+
// TaskCanceledException subclass), not be swallowed into a timeout CosmosException.
148+
}
149+
}
150+
97151
private static Mock<AuthorizationTokenProvider> CreateMockAuthorizationTokenProvider()
98152
{
99153
Mock<AuthorizationTokenProvider> mockAuth = new Mock<AuthorizationTokenProvider>();
@@ -132,5 +186,35 @@ protected override Task<HttpResponseMessage> SendAsync(
132186
return Task.FromResult(response);
133187
}
134188
}
189+
190+
/// <summary>
191+
/// HttpMessageHandler that delays for a configurable duration before responding.
192+
/// Used to exercise the per-request inference timeout.
193+
/// </summary>
194+
private class DelayedMessageHandler : HttpMessageHandler
195+
{
196+
private readonly TimeSpan delay;
197+
private readonly HttpStatusCode statusCode;
198+
private readonly string responseContent;
199+
200+
public DelayedMessageHandler(TimeSpan delay, HttpStatusCode statusCode, string responseContent)
201+
{
202+
this.delay = delay;
203+
this.statusCode = statusCode;
204+
this.responseContent = responseContent;
205+
}
206+
207+
protected override async Task<HttpResponseMessage> SendAsync(
208+
HttpRequestMessage request,
209+
CancellationToken cancellationToken)
210+
{
211+
await Task.Delay(this.delay, cancellationToken);
212+
213+
return new HttpResponseMessage(this.statusCode)
214+
{
215+
Content = new StringContent(this.responseContent, Encoding.UTF8, "application/json")
216+
};
217+
}
218+
}
135219
}
136220
}

0 commit comments

Comments
 (0)