diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs index 7f7a22b466..771059e5be 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs @@ -284,6 +284,32 @@ public static QueryDefinition ToQueryDefinition(this IQueryable query) throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query)); } + /// + /// This extension method returns the query as an asynchronous enumerable. + /// + /// the type of object to query. + /// the IQueryable{T} to be converted. + /// An asynchronous enumerable to go through the items. + /// + /// This example shows how to get the query as an asynchronous enumerable. + /// + /// + /// linqQueryable = this.Container.GetItemLinqQueryable(); + /// IAsyncEnumerable asyncEnumerable = linqQueryable.Where(item => (item.taskNum < 100)).AsAsyncEnumerable(); + /// ]]> + /// + /// + public static IAsyncEnumerable AsAsyncEnumerable(this IQueryable query) + { + if (query is CosmosLinqQuery asyncEnumerable) + { + return asyncEnumerable; + } + + throw new ArgumentException("AsAsyncEnumerable is only supported on Cosmos LINQ query operations", nameof(query)); + } + /// /// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. /// This will create the fresh new FeedIterator when called. diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs index 6676d096c9..d445eea3c2 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs @@ -22,7 +22,7 @@ namespace Microsoft.Azure.Cosmos.Linq /// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable. /// /// - internal sealed class CosmosLinqQuery : IDocumentQuery, IOrderedQueryable + internal sealed class CosmosLinqQuery : IDocumentQuery, IOrderedQueryable, IAsyncEnumerable { private readonly CosmosLinqQueryProvider queryProvider; private readonly Guid correlatedActivityId; @@ -109,7 +109,7 @@ public IEnumerator GetEnumerator() " use GetItemQueryIterator to execute asynchronously"); } - FeedIterator localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind); + using FeedIterator localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind); Debug.Assert( scalarOperationKind == ScalarOperationKind.None, "CosmosLinqQuery Assert!", @@ -128,6 +128,33 @@ public IEnumerator GetEnumerator() } } + /// + /// Retrieves an object that can iterate through the individual results of the query asynchronously. + /// + /// + /// This triggers an asynchronous multi-page load. + /// + /// Cancellation token + /// IEnumerator + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + using FeedIteratorInlineCore localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind); + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); + + while (localFeedIterator.HasMoreResults) + { + FeedResponse response = await localFeedIterator.ReadNextAsync(cancellationToken); + + foreach (T item in response) + { + yield return item; + } + } + } + /// /// Synchronous Multi-Page load /// diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemLinqTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemLinqTests.cs index efffeae1cb..356b296165 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemLinqTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosItemLinqTests.cs @@ -141,6 +141,31 @@ public void LinqQueryToIteratorBlockTest(bool isStreamIterator) } } + [TestMethod] + public async Task LinqQueryToAsyncEnumerable() + { + ToDoActivity toDoActivity = ToDoActivity.CreateRandomToDoActivity(); + toDoActivity.taskNum = 20; + toDoActivity.id = "minTaskNum"; + await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk)); + toDoActivity.taskNum = 100; + toDoActivity.id = "maxTaskNum"; + await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk)); + + IAsyncEnumerable query = this.Container.GetItemLinqQueryable() + .OrderBy(p => p.cost) + .AsAsyncEnumerable(); + + int found = 0; + await foreach (ToDoActivity item in query) + { + Assert.IsNotNull(item); + ++found; + } + + Assert.AreEqual(2, found); + } + [TestMethod] [DataRow(false)] [DataRow(true)]