Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement IAsyncEnumerable on CosmosLinqQuery #4355

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,32 @@ public static QueryDefinition ToQueryDefinition<T>(this IQueryable<T> query)
throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query));
}

/// <summary>
/// This extension method returns the query as an asynchronous enumerable.
/// </summary>
/// <typeparam name="T">the type of object to query.</typeparam>
/// <param name="query">the IQueryable{T} to be converted.</param>
/// <returns>An asynchronous enumerable to go through the items.</returns>
/// <example>
/// This example shows how to get the query as an asynchronous enumerable.
///
/// <code language="c#">
/// <![CDATA[
/// IOrderedQueryable<ToDoActivity> linqQueryable = this.Container.GetItemLinqQueryable<ToDoActivity>();
/// IAsyncEnumerable<ToDoActivity> asyncEnumerable = linqQueryable.Where(item => (item.taskNum < 100)).AsAsyncEnumerable();
/// ]]>
/// </code>
/// </example>
public static IAsyncEnumerable<T> AsAsyncEnumerable<T>(this IQueryable<T> query)
{
if (query is IAsyncEnumerable<T> asyncEnumerable)
onionhammer marked this conversation as resolved.
Show resolved Hide resolved
{
return asyncEnumerable;
}

throw new ArgumentException("AsAsyncEnumerable is only supported on Cosmos LINQ query operations", nameof(query));
}

/// <summary>
/// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously.
/// This will create the fresh new FeedIterator when called.
Expand Down
20 changes: 19 additions & 1 deletion Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace Microsoft.Azure.Cosmos.Linq
/// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable.
/// </summary>
/// <seealso cref="CosmosLinqQueryProvider"/>
internal sealed class CosmosLinqQuery<T> : IDocumentQuery<T>, IOrderedQueryable<T>
internal sealed class CosmosLinqQuery<T> : IDocumentQuery<T>, IOrderedQueryable<T>, IAsyncEnumerable<T>
{
private readonly CosmosLinqQueryProvider queryProvider;
private readonly Guid correlatedActivityId;
Expand Down Expand Up @@ -283,5 +283,23 @@ private FeedIteratorInlineCore<T> CreateFeedIterator(bool isContinuationExpected
this.responseFactory.CreateQueryFeedUserTypeResponse<T>),
this.container.ClientContext);
}

public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
onionhammer marked this conversation as resolved.
Show resolved Hide resolved
{
using FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = await localFeedIterator.ReadNextAsync(cancellationToken);
foreach (T item in response)
{
yield return item;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ public void LinqQueryToIteratorBlockTest(bool isStreamIterator)
}
}

[TestMethod]
public async Task LinqQueryToAsyncEnumerable()
{
ToDoActivity toDoActivity = ToDoActivity.CreateRandomToDoActivity();
toDoActivity.taskNum = 20;
toDoActivity.id = "minTaskNum";
await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk));
toDoActivity.taskNum = 100;
toDoActivity.id = "maxTaskNum";
await this.Container.CreateItemAsync(toDoActivity, new PartitionKey(toDoActivity.pk));

IAsyncEnumerable<ToDoActivity> query = this.Container.GetItemLinqQueryable<ToDoActivity>()
.OrderBy(p => p.cost)
.AsAsyncEnumerable();

int found = 0;
await foreach (ToDoActivity item in query)
{
Assert.IsNotNull(item);
++found;
}

Assert.IsTrue(found > 0);
onionhammer marked this conversation as resolved.
Show resolved Hide resolved
}

[TestMethod]
[DataRow(false)]
[DataRow(true)]
Expand Down