Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,26 +216,26 @@ For types that fit well in SIMD registers, the library replaces the scalar
comparing one pair at a time, an entire layer of the sorting network executes
in a few instructions.

Here is a simplified example for `byte` with AVX2 — all 27 elements fit in a
single `Vector256<byte>` (32 lanes, 5 unused):
Here is a simplified example for `byte` with AVX-512 VBMI — all 27 elements fit in a
single `Vector512<byte>` (64 lanes, 37 unused):

```csharp
// Load all elements into one 256-bit vector
var vec = LoadVector256(ref first); // [e0, e1, ..., e26, 0, 0, 0, 0, 0]
// Load all elements into one 512-bit vector
var vec = LoadVector512(ref first); // [e0, e1, ..., e26, pad...]

// For each of the 13 layers:
// 1. Shuffle: rearrange elements to pair up comparators
var shuffled = Vector256.Shuffle(vec, layerPermutation);
var shuffled = Avx512Vbmi.PermuteVar64x8(vec, layerPermutation);

// 2. Min/Max: compare all pairs simultaneously
var mins = Vector256.Min(vec, shuffled);
var maxs = Vector256.Max(vec, shuffled);
var mins = Vector512.Min(vec, shuffled);
var maxs = Vector512.Max(vec, shuffled);

// 3. Select: pick min or max for each position using a mask
vec = Vector256.ConditionalSelect(layerMask, maxs, mins);
vec = Vector512.ConditionalSelect(layerMask, maxs, mins);

// After all 13 layers, store the sorted vector back
StoreVector256(ref first, vec);
StoreVector512(ref first, vec);
```

Each layer becomes a small handful of SIMD instructions — **shuffle**,
Expand Down Expand Up @@ -265,12 +265,13 @@ compiles to a single CPU comparison instruction matching the BCL's internal
sort helpers; for custom types the JIT can devirtualize `CompareTo` on value
types, keeping the call nearly as cheap.

For `byte` and `sbyte`, the generator additionally emits SIMD vectorization
when available — AVX2 on x86 and AdvSimd (NEON) on ARM64. For sizes up to 32,
all elements fit in a single vector register (or two on ARM64), allowing each
network step to execute as a vectorized shuffle + min/max + blend operation.
On ARM64, SIMD extends up to 64 elements using up to four `Vector128<byte>`
registers with single-group TBL4 lookups.
For `byte` and `sbyte`, the generator emits SIMD vectorization when available.
On x86 with AVX-512 VBMI, all elements (sizes 8-64) fit in a single
`Vector512<byte>` register using `PermuteVar64x8` shuffles — the fastest path.
On x86 without VBMI, AVX2 is used as a fallback for sizes 8-32 with
`Vector256<byte>` and `Avx2.Shuffle` (vpshufb). On ARM64 AdvSimd (NEON), up to
four `Vector128<byte>` registers with single-group TBL4 lookups handle sizes up
to 64 elements.

For `int` and `uint`, AVX2 SIMD is emitted on x86 with four `Vector256<int>`
registers (8 elements each). Cross-vector shuffles use `PermuteVar8x32` with
Expand Down Expand Up @@ -411,14 +412,21 @@ inlines the call for value types, keeping overhead low:

### x86 AVX-512F (AMD EPYC 9V74, GitHub Actions)

On CPUs with AVX-512F (e.g., AMD EPYC, Intel Ice Lake+), additional SIMD paths
are available. For `byte` and `sbyte`, the AVX2 path uses optimized direct
`Avx2.Shuffle` intrinsics (vpshufb) avoiding the expensive cross-lane emulation:
On CPUs with AVX-512 VBMI (e.g., AMD Zen 4+, Intel Ice Lake+), `byte` and `sbyte`
use a single `Vector512<byte>` with `PermuteVar64x8` shuffles for all sizes:

| Type | ArraySort (27) | GeneratedSort (27) | Speedup |
|---|---|---|---|
| byte | 1,591 ns | 61 ns | **26x** |
| sbyte | 1,714 ns | 68 ns | **25x** |
| Type | Size | ArraySort | GeneratedSort | Speedup |
|---|---|---|---|---|
| byte | 27 | 1,250 ns | 53 ns | **24x** |
| byte | 28 | 1,415 ns | 54 ns | **26x** |
| byte | 32 | 1,516 ns | 54 ns | **28x** |
| byte | 34 | 1,759 ns | 64 ns | **27x** |
| sbyte | 27 | 1,355 ns | 57 ns | **24x** |
| sbyte | 28 | 1,495 ns | 58 ns | **26x** |
| sbyte | 32 | 1,598 ns | 58 ns | **28x** |
| sbyte | 38 | 2,160 ns | 68 ns | **32x** |

> On CPUs without AVX-512 VBMI, byte/sbyte sizes 8-32 fall back to AVX2.

For `int`, `uint`, and `float`, AVX-512F uses two `Vector512` registers with
`PermuteVar16x32x2` cross-vector shuffles:
Expand Down
97 changes: 82 additions & 15 deletions SortingNetworks.Generators/SimdX86Emitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,15 @@ internal static string GetGuardCondition(SpecialType specialType)

/// <summary>
/// Returns the SIMD guard condition string for the given element type and size.
/// For byte/sbyte sizes > 32, AVX-512 VBMI is required instead of AVX2.
/// For byte/sbyte, AVX-512 VBMI is the primary path for all sizes.
/// AVX2 is available as a fallback for sizes ≤ 32 via <see cref="CanEmitAvx2Fallback"/>.
/// </summary>
internal static string GetGuardCondition(SpecialType specialType, int size)
{
int elemBytes = ElementSize(specialType);
switch (elemBytes)
{
case 1:
return size > 32
? "System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported"
: "System.Runtime.Intrinsics.X86.Avx2.IsSupported";
case 1: return "System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported";
case 2: return "System.Runtime.Intrinsics.X86.Avx512BW.IsSupported";
case 4: return "System.Runtime.Intrinsics.X86.Avx2.IsSupported";
case 8: return "System.Runtime.Intrinsics.X86.Avx512F.IsSupported";
Expand All @@ -60,6 +58,9 @@ internal static string GetGuardCondition(SpecialType specialType, int size)
/// </summary>
internal static bool CanEmitAvx2Fallback(SpecialType specialType, int size)
{
// 8-bit types: AVX2 fallback for sizes 8-32 (single Vector256<byte>)
if (ElementSize(specialType) == 1 && size >= 8 && size <= 32)
return true;
// 16-bit types: AVX2 fallback for sizes 8-16 (single Vector256<ushort>)
if (ElementSize(specialType) == 2 && size >= 8 && size <= 16)
return true;
Expand Down Expand Up @@ -92,6 +93,8 @@ internal static string GetAvx2FallbackGuardCondition()
internal static (string MethodSource, string DispatchCode) EmitAvx2Fallback(int size, string typeName, SpecialType specialType, List<List<(int A, int B)>> steps)
{
if (!CanEmitAvx2Fallback(specialType, size)) return ("", "");
if (ElementSize(specialType) == 1)
return EmitByteAvx2(size, typeName, specialType, steps);
Comment thread
jonathanpeppers marked this conversation as resolved.
if (specialType == SpecialType.System_Double)
return EmitDoubleAvx2(size, steps);
return EmitShortAvx2(size, typeName, specialType, steps);
Expand All @@ -107,7 +110,7 @@ internal static (string MethodSource, string DispatchCode) Emit(int size, string

switch (elemBytes)
{
case 1: return EmitByte(size, typeName, specialType, steps);
case 1: return EmitByteAvx512Vbmi(size, typeName, specialType, steps);
case 2: return EmitShort(size, typeName, specialType, steps);
case 4: return EmitInt32(size, typeName, specialType, steps);
case 8: return EmitInt64(size, typeName, specialType, steps);
Expand Down Expand Up @@ -146,18 +149,18 @@ internal static (string MethodSource, string DispatchCode) Emit(int size, string
return steps;
}

// --- Byte (8-bit) types: single Vector256 for sizes ≤ 32, Vector512 for 33-64 ---
// --- Byte (8-bit) types: AVX2 fallback for sizes 8-32 using Vector256<byte> ---

private static (string, string) EmitByte(int size, string typeName, SpecialType specialType, List<List<(int A, int B)>> steps)
private static (string, string) EmitByteAvx2(int size, string typeName, SpecialType specialType, List<List<(int A, int B)>> steps)
{
if (size > 32) return EmitByteAvx512Vbmi(size, typeName, specialType, steps);
if (size > 32) return ("", "");

var sb = new StringBuilder();
string suffix = $"_{typeName}";

// Method signature
sb.AppendLine($" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveOptimization)]");
sb.AppendLine($" private static void SortSimd{size}{suffix}(System.Span<{typeName}> span)");
sb.AppendLine($" private static void SortSimdAvx2_{size}{suffix}(System.Span<{typeName}> span)");
sb.AppendLine(" {");
sb.AppendLine($" ref byte first = ref System.Runtime.CompilerServices.Unsafe.As<{typeName}, byte>(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span));");

Expand Down Expand Up @@ -266,14 +269,14 @@ private static (string, string) EmitByte(int size, string typeName, SpecialType
string dispatchSb =
$" if (System.Runtime.Intrinsics.X86.Avx2.IsSupported)\n" +
$" {{\n" +
$" SortSimd{size}{suffix}(span);\n" +
$" SortSimdAvx2_{size}{suffix}(span);\n" +
$" return;\n" +
$" }}\n";

return (sb.ToString(), dispatchSb);
}

// --- Byte (8-bit) types: single Vector512 for sizes 33-64 using AVX-512 VBMI ---
// --- Byte (8-bit) types: single Vector512 for all sizes 8-64 using AVX-512 VBMI ---

private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, SpecialType specialType, List<List<(int A, int B)>> steps)
{
Expand All @@ -292,7 +295,7 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp
{
sb.AppendLine(" var vec = System.Runtime.CompilerServices.Unsafe.ReadUnaligned<System.Runtime.Intrinsics.Vector512<byte>>(ref first);");
}
else
else if (size > 32)
{
// Partial: read lower 256 bits directly, upper 256 bits with overlap
int hiReadOffset = size - 32;
Expand All @@ -319,6 +322,47 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp
sb.AppendLine($" System.Runtime.Intrinsics.Vector512.Create(lo, hiRaw),");
sb.AppendLine($" {FmtVec512Byte(loadPerm)});");
}
else if (size == 32)
{
// Full Vector256: zero-extend to Vector512
sb.AppendLine(" var vec = System.Runtime.Intrinsics.Vector512.Create(");
sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.ReadUnaligned<System.Runtime.Intrinsics.Vector256<byte>>(ref first),");
sb.AppendLine(" System.Runtime.Intrinsics.Vector256<byte>.Zero);");
}
else if (size > 16)
{
// 17-31: overlapping Vector128 reads, combine to Vector256, zero-extend to Vector512
Comment thread
jonathanpeppers marked this conversation as resolved.
int hiReadOffset = size - 16;
sb.AppendLine(" var lo128 = System.Runtime.CompilerServices.Unsafe.ReadUnaligned<System.Runtime.Intrinsics.Vector128<byte>>(ref first);");
sb.AppendLine($" var hi128 = System.Runtime.CompilerServices.Unsafe.ReadUnaligned<System.Runtime.Intrinsics.Vector128<byte>>(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));");
// Build permutation to rearrange into proper element order.
// lo128[j] = element j for j in [0..15]
// hi128[j] = element (hiReadOffset + j) for j in [0..15]
// Combined: lo128 at [0..15], hi128 at [16..31], zeros at [32..63]
// For j >= 16: element j is at hi128[j - hiReadOffset] → source index 16 + (j - hiReadOffset)
byte[] loadPerm = new byte[64];
for (int i = 0; i < 64; i++)
{
if (i < 16)
loadPerm[i] = (byte)i;
else if (i < size)
loadPerm[i] = (byte)(i + 32 - size);
else
loadPerm[i] = 0;
}
sb.AppendLine($" var vec = System.Runtime.Intrinsics.X86.Avx512Vbmi.PermuteVar64x8(");
sb.AppendLine($" System.Runtime.Intrinsics.Vector512.Create(System.Runtime.Intrinsics.Vector256.Create(lo128, hi128), System.Runtime.Intrinsics.Vector256<byte>.Zero),");
sb.AppendLine($" {FmtVec512Byte(loadPerm)});");
}
else
{
// 8-16: single Vector128 read, zero-extend to Vector512
sb.AppendLine(" var vec = System.Runtime.Intrinsics.Vector512.Create(");
sb.AppendLine(" System.Runtime.Intrinsics.Vector256.Create(");
sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.ReadUnaligned<System.Runtime.Intrinsics.Vector128<byte>>(ref first),");
sb.AppendLine(" System.Runtime.Intrinsics.Vector128<byte>.Zero),");
sb.AppendLine(" System.Runtime.Intrinsics.Vector256<byte>.Zero);");
}

// Emit each step
for (int si = 0; si < steps.Count; si++)
Expand Down Expand Up @@ -352,7 +396,7 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp
{
sb.AppendLine(" vec.StoreUnsafe(ref first);");
}
else
else if (size > 32)
{
int hiReadOffset = size - 32;
// Build permutation to extract upper portion for overlapping store
Expand All @@ -368,6 +412,29 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp
sb.AppendLine($" .GetUpper().StoreUnsafe(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));");
sb.AppendLine(" vec.GetLower().StoreUnsafe(ref first);");
}
else if (size == 32)
{
// Store lower Vector256
sb.AppendLine(" vec.GetLower().StoreUnsafe(ref first);");
}
else if (size > 16)
{
// 17-31: overlapping Vector128 stores
int hiReadOffset = size - 16;
// Extract elements [size-16, size) into lower Vector128 via permutation
byte[] hiStorePerm = new byte[64];
for (int i = 0; i < 64; i++) hiStorePerm[i] = 0;
for (int i = 0; i < 16; i++)
hiStorePerm[i] = (byte)(size - 16 + i);
sb.AppendLine($" System.Runtime.Intrinsics.X86.Avx512Vbmi.PermuteVar64x8(vec, {FmtVec512Byte(hiStorePerm)})");
sb.AppendLine($" .GetLower().GetLower().StoreUnsafe(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));");
sb.AppendLine(" vec.GetLower().GetLower().StoreUnsafe(ref first);");
}
else
{
// 8-16: store lower Vector128
sb.AppendLine(" vec.GetLower().GetLower().StoreUnsafe(ref first);");
}

sb.AppendLine(" }");

Expand Down Expand Up @@ -1474,7 +1541,7 @@ private static int MaxElements(int elemBytes)
// Max elements we support with SIMD
switch (elemBytes)
{
case 1: return 64; // Vector256<byte> (≤32) or Vector512<byte> with VBMI (33-64)
case 1: return 64; // Vector512<byte> with VBMI (all sizes 8-64), AVX2 fallback for ≤32
case 2: return 64; // 2x Vector512<ushort> with PermuteVar32x16x2
case 4: return 64; // 8x Vector256<int>
case 8: return 32; // 4x Vector512<long> (double uses AVX2 fallback for >32)
Expand Down
6 changes: 6 additions & 0 deletions SortingNetworks.Tests/GeneratedSortTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,18 @@ private static void StressSort<T>(int size, Func<Random, T> generator, Action<T[
[Fact]
public void Sort_28Elements_Byte() => StressSort(28, rng => (byte)rng.Next(0, 256), a => GeneratedSorters.Sort(a.AsSpan()));

[Fact]
public void Sort_32Elements_Byte() => StressSort(32, rng => (byte)rng.Next(0, 256), a => GeneratedSorters.Sort(a.AsSpan()));

[Fact]
public void Sort_27Elements_SByte() => StressSort(27, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan()));

[Fact]
public void Sort_28Elements_SByte() => StressSort(28, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan()));

[Fact]
public void Sort_32Elements_SByte() => StressSort(32, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan()));

[Fact]
public void Sort_27Elements_Short() => StressSort(27, rng => (short)rng.Next(-1000, 1000), a => GeneratedSorters.Sort(a.AsSpan()));

Expand Down
2 changes: 2 additions & 0 deletions SortingNetworks.Tests/GeneratedSorters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ namespace SortingNetworks.Tests;
[SortingNetwork(16, typeof(byte))]
[SortingNetwork(27, typeof(byte))]
[SortingNetwork(28, typeof(byte))]
[SortingNetwork(32, typeof(byte))]
[SortingNetwork(48, typeof(byte))]
[SortingNetwork(64, typeof(byte))]
[SortingNetwork(27, typeof(sbyte))]
[SortingNetwork(28, typeof(sbyte))]
[SortingNetwork(32, typeof(sbyte))]
[SortingNetwork(48, typeof(sbyte))]
[SortingNetwork(64, typeof(sbyte))]
[SortingNetwork(8, typeof(short))]
Expand Down
38 changes: 38 additions & 0 deletions SortingNetworks.Tests/GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ public partial class MySorter { }
[InlineData(64, "float")]
[InlineData(8, "byte")]
[InlineData(16, "byte")]
[InlineData(32, "byte")]
[InlineData(48, "byte")]
[InlineData(64, "byte")]
[InlineData(8, "ushort")]
Expand Down Expand Up @@ -473,6 +474,43 @@ public partial class MySorter {{ }}
Assert.Contains("Avx2.IsSupported", generatedSource);
}

[Theory]
[InlineData(8, "byte")]
[InlineData(16, "byte")]
[InlineData(28, "byte")]
[InlineData(32, "byte")]
[InlineData(8, "sbyte")]
[InlineData(16, "sbyte")]
[InlineData(28, "sbyte")]
[InlineData(32, "sbyte")]
public void SimdCode_8Bit_HasAvx2Fallback(int size, string typeName)
{
var source = $@"
using SortingNetworks;

[SortingNetwork({size}, typeof({typeName}))]
public partial class MySorter {{ }}
";
var compilation = SourceGeneratorDriver.CreateCompilation(source);
var (result, updatedCompilation) = SourceGeneratorDriver.RunGeneratorWithCompilation(compilation);

var errors = result.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToArray();
Assert.Empty(errors);

var compilationErrors = SourceGeneratorDriver.GetErrors(updatedCompilation);
Assert.Empty(compilationErrors);

// Verify both AVX-512 VBMI and AVX2 methods were generated
var generatedSource = result.GeneratedTrees
.Select(t => t.GetText().ToString())
.FirstOrDefault(s => s.Contains($"SortSimd{size}_{typeName}") && s.Contains($"SortSimdAvx2_{size}_{typeName}"));
Assert.NotNull(generatedSource);

// Verify dispatch cascades from AVX-512 VBMI to AVX2
Assert.Contains("Avx512Vbmi.IsSupported", generatedSource);
Assert.Contains("Avx2.IsSupported", generatedSource);
}

[Fact]
public void ComparerOverload_GeneratesCode()
{
Expand Down