Skip to content

Commit 4008f82

Browse files
authored
Add LINQ Shuffle (#112173)
1 parent f810340 commit 4008f82

21 files changed

+982
-122
lines changed

src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ public static partial class AsyncEnumerable
132132
public static System.Collections.Generic.IAsyncEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<TResult>> selector) { throw null; }
133133
public static System.Collections.Generic.IAsyncEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, TResult> selector) { throw null; }
134134
public static System.Threading.Tasks.ValueTask<bool> SequenceEqualAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> first, System.Collections.Generic.IAsyncEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
135+
public static System.Collections.Generic.IAsyncEnumerable<TSource> Shuffle<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source) { throw null; }
135136
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, bool> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
136137
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Func<TSource, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<bool>> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
137138
public static System.Threading.Tasks.ValueTask<TSource> SingleAsync<TSource>(this System.Collections.Generic.IAsyncEnumerable<TSource> source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }

src/libraries/System.Linq.AsyncEnumerable/src/System.Linq.AsyncEnumerable.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
</PropertyGroup>
1313

1414
<ItemGroup>
15+
<Compile Include="System\Linq\Shuffle.cs" />
1516
<Compile Include="System\Linq\SkipLast.cs" />
1617
<Compile Include="System\Linq\SkipWhile.cs" />
1718
<Compile Include="System\Linq\Append.cs" />
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
using System.Runtime.CompilerServices;
6+
using System.Threading;
7+
8+
namespace System.Linq
9+
{
10+
public static partial class AsyncEnumerable
11+
{
12+
#if !NET
13+
[ThreadStatic]
14+
private static Random? t_random;
15+
#endif
16+
17+
/// <summary>Shuffles the order of the elements of a sequence.</summary>
18+
/// <typeparam name="TSource">The type of the elements of <paramref name="source"/>.</typeparam>
19+
/// <param name="source">A sequence of values to shuffle.</param>
20+
/// <returns>A sequence whose elements correspond to those of the input sequence in randomized order.</returns>
21+
/// <remarks>Randomization is performed using a non-cryptographically-secure random number generator.</remarks>
22+
public static IAsyncEnumerable<TSource> Shuffle<TSource>(
23+
this IAsyncEnumerable<TSource> source)
24+
{
25+
ThrowHelper.ThrowIfNull(source);
26+
27+
return Impl(source, default);
28+
29+
static async IAsyncEnumerable<TSource> Impl(
30+
IAsyncEnumerable<TSource> source,
31+
[EnumeratorCancellation] CancellationToken cancellationToken)
32+
{
33+
TSource[] array = await source.ToArrayAsync(cancellationToken).ConfigureAwait(false);
34+
35+
#if NET
36+
Random.Shared.Shuffle(array);
37+
#else
38+
Random random = t_random ??= new Random(Environment.TickCount ^ Environment.CurrentManagedThreadId);
39+
int n = array.Length;
40+
for (int i = 0; i < n - 1; i++)
41+
{
42+
int j = random.Next(i, n);
43+
if (j != i)
44+
{
45+
TSource temp = array[i];
46+
array[i] = array[j];
47+
array[j] = temp;
48+
}
49+
}
50+
#endif
51+
52+
for (int i = 0; i < array.Length; i++)
53+
{
54+
yield return array[i];
55+
}
56+
}
57+
}
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
using Xunit;
8+
9+
namespace System.Linq.Tests
10+
{
11+
public class ShuffleTests : AsyncEnumerableTests
12+
{
13+
[Fact]
14+
public void InvalidInputs_Throws()
15+
{
16+
AssertExtensions.Throws<ArgumentNullException>("source", () => AsyncEnumerable.Shuffle<int>(null));
17+
}
18+
19+
[Theory]
20+
[InlineData(new int[0])]
21+
[InlineData(new int[] { 1 })]
22+
[InlineData(new int[] { 2, 4, 8 })]
23+
[InlineData(new int[] { -1, 2, 5, 6, 7, 8 })]
24+
public async Task VariousValues_ContainsAllInputValues(int[] values)
25+
{
26+
foreach (IAsyncEnumerable<int> source in CreateSources(values))
27+
{
28+
int[] shuffled = await source.Shuffle().ToArrayAsync();
29+
Array.Sort(shuffled);
30+
Assert.Equal(values, shuffled);
31+
}
32+
}
33+
34+
[Fact]
35+
public async Task ToArrayAsync_ElementsAreRandomized()
36+
{
37+
// The chance that shuffling a thousand elements produces the same order twice is infinitesimal.
38+
int length = 1000;
39+
foreach (IAsyncEnumerable<int> source in CreateSources(Enumerable.Range(0, length).ToArray()))
40+
{
41+
int[] first = await source.Shuffle().ToArrayAsync();
42+
int[] second = await source.Shuffle().ToArrayAsync();
43+
Assert.Equal(length, first.Length);
44+
Assert.Equal(length, second.Length);
45+
Assert.NotEqual(first, second);
46+
}
47+
}
48+
49+
[Fact]
50+
public async Task Cancellation_Cancels()
51+
{
52+
IAsyncEnumerable<int> source = CreateSource(2, 4, 8, 16);
53+
await Assert.ThrowsAsync<OperationCanceledException>(async () =>
54+
{
55+
await ConsumeAsync(source.Shuffle().WithCancellation(new CancellationToken(true)));
56+
});
57+
}
58+
59+
[Fact]
60+
public async Task InterfaceCalls_ExpectedCounts()
61+
{
62+
TrackingAsyncEnumerable<int> source = CreateSource(2, 4, 8, 16).Track();
63+
await ConsumeAsync(source.Shuffle());
64+
Assert.Equal(5, source.MoveNextAsyncCount);
65+
Assert.Equal(4, source.CurrentCount);
66+
Assert.Equal(1, source.DisposeAsyncCount);
67+
}
68+
}
69+
}

src/libraries/System.Linq.AsyncEnumerable/tests/System.Linq.AsyncEnumerable.Tests.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
<Compile Include="ElementAtOrDefaultAsyncTests.cs" />
1111
<Compile Include="ElementAtAsyncTests.cs" />
1212
<Compile Include="GroupByTests.cs" />
13+
<Compile Include="ShuffleTests.cs" />
1314
<Compile Include="SingleOrDefaultAsyncTests.cs" />
1415
<Compile Include="SingleAsyncTests.cs" />
1516
<Compile Include="LastOrDefaultAsyncTests.cs" />

src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ public static partial class Queryable
158158
public static System.Linq.IQueryable<TResult> Select<TSource, TResult>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TResult>> selector) { throw null; }
159159
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2) { throw null; }
160160
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
161+
public static System.Linq.IQueryable<TSource> Shuffle<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
161162
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
162163
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
163164
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate, TSource defaultValue) { throw null; }

src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs

+12
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,18 @@ public static bool SequenceEqual<TSource>(this IQueryable<TSource> source1, IEnu
15871587
Expression.Constant(comparer, typeof(IEqualityComparer<TSource>))));
15881588
}
15891589

1590+
[DynamicDependency("Shuffle`1", typeof(Enumerable))]
1591+
public static IQueryable<TSource> Shuffle<TSource>(this IQueryable<TSource> source)
1592+
{
1593+
ArgumentNullException.ThrowIfNull(source);
1594+
1595+
return source.Provider.CreateQuery<TSource>(
1596+
Expression.Call(
1597+
null,
1598+
new Func<IQueryable<TSource>, IQueryable<TSource>>(Shuffle).Method,
1599+
source.Expression));
1600+
}
1601+
15901602
[DynamicDependency("Any`1", typeof(Enumerable))]
15911603
public static bool Any<TSource>(this IQueryable<TSource> source)
15921604
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Xunit;
5+
6+
namespace System.Linq.Tests
7+
{
8+
public class ShuffleTests : EnumerableBasedTests
9+
{
10+
[Fact]
11+
public void InvalidArguments()
12+
{
13+
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<string>)null).Shuffle());
14+
}
15+
16+
[Fact]
17+
public void ProducesAllElements()
18+
{
19+
int[] shuffled = Enumerable.Range(0, 1000).AsQueryable().Shuffle().ToArray();
20+
Array.Sort(shuffled);
21+
Assert.Equal(Enumerable.Range(0, shuffled.Length), shuffled);
22+
}
23+
24+
[Fact]
25+
public void ElementsAreRandomized()
26+
{
27+
// The chance that shuffling a thousand elements produces the same order twice is infinitesimal.
28+
const int Length = 1000;
29+
IQueryable<int> source = Enumerable.Range(0, Length).AsQueryable().Shuffle();
30+
int[] first = source.ToArray();
31+
int[] second = source.ToArray();
32+
Assert.Equal(Length, first.Length);
33+
Assert.Equal(Length, second.Length);
34+
Assert.NotEqual(first, second);
35+
}
36+
}
37+
}

src/libraries/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
<Compile Include="JoinTests.cs" />
3232
<Compile Include="LastOrDefaultTests.cs" />
3333
<Compile Include="LastTests.cs" />
34+
<Compile Include="ShuffleTests.cs" />
3435
<Compile Include="RightJoinTests.cs" />
3536
<Compile Include="LongCountTests.cs" />
3637
<Compile Include="MaxTests.cs" />
@@ -62,7 +63,6 @@
6263
<Compile Include="UnionTests.cs" />
6364
<Compile Include="WhereTests.cs" />
6465
<Compile Include="ZipTests.cs" />
65-
<Compile Include="$(CommonTestPath)System\Linq\SkipTakeData.cs"
66-
Link="Common\System\Linq\SkipTakeData.cs" />
66+
<Compile Include="$(CommonTestPath)System\Linq\SkipTakeData.cs" Link="Common\System\Linq\SkipTakeData.cs" />
6767
</ItemGroup>
6868
</Project>

src/libraries/System.Linq/ref/System.Linq.cs

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
// Changes to this file must follow the https://aka.ms/api-review process.
55
// ------------------------------------------------------------------------------
66

7+
using System.Collections.Generic;
8+
79
namespace System.Linq
810
{
911
public static partial class Enumerable
@@ -172,6 +174,7 @@ public static System.Collections.Generic.IEnumerable<
172174
public static System.Collections.Generic.IEnumerable<TResult> Select<TSource, TResult>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TResult> selector) { throw null; }
173175
public static bool SequenceEqual<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
174176
public static bool SequenceEqual<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
177+
public static System.Collections.Generic.IEnumerable<TSource> Shuffle<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
175178
public static TSource? SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
176179
public static TSource SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, TSource defaultValue) { throw null; }
177180
public static TSource? SingleOrDefault<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }

src/libraries/System.Linq/src/System.Linq.csproj

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
<Compile Include="System\Linq\Range.SpeedOpt.cs" />
5555
<Compile Include="System\Linq\Repeat.cs" />
5656
<Compile Include="System\Linq\Repeat.SpeedOpt.cs" />
57+
<Compile Include="System\Linq\Shuffle.SpeedOpt.cs" />
58+
<Compile Include="System\Linq\Shuffle.cs" />
5759
<Compile Include="System\Linq\Reverse.cs" />
5860
<Compile Include="System\Linq\Reverse.SpeedOpt.cs" />
5961
<Compile Include="System\Linq\RightJoin.cs" />

0 commit comments

Comments
 (0)