diff --git a/README.md b/README.md index a140f16..f6c7475 100644 --- a/README.md +++ b/README.md @@ -216,26 +216,26 @@ For types that fit well in SIMD registers, the library replaces the scalar comparing one pair at a time, an entire layer of the sorting network executes in a few instructions. -Here is a simplified example for `byte` with AVX2 — all 27 elements fit in a -single `Vector256` (32 lanes, 5 unused): +Here is a simplified example for `byte` with AVX-512 VBMI — all 27 elements fit in a +single `Vector512` (64 lanes, 37 unused): ```csharp -// Load all elements into one 256-bit vector -var vec = LoadVector256(ref first); // [e0, e1, ..., e26, 0, 0, 0, 0, 0] +// Load all elements into one 512-bit vector +var vec = LoadVector512(ref first); // [e0, e1, ..., e26, pad...] // For each of the 13 layers: // 1. Shuffle: rearrange elements to pair up comparators -var shuffled = Vector256.Shuffle(vec, layerPermutation); +var shuffled = Avx512Vbmi.PermuteVar64x8(vec, layerPermutation); // 2. Min/Max: compare all pairs simultaneously -var mins = Vector256.Min(vec, shuffled); -var maxs = Vector256.Max(vec, shuffled); +var mins = Vector512.Min(vec, shuffled); +var maxs = Vector512.Max(vec, shuffled); // 3. Select: pick min or max for each position using a mask -vec = Vector256.ConditionalSelect(layerMask, maxs, mins); +vec = Vector512.ConditionalSelect(layerMask, maxs, mins); // After all 13 layers, store the sorted vector back -StoreVector256(ref first, vec); +StoreVector512(ref first, vec); ``` Each layer becomes a small handful of SIMD instructions — **shuffle**, @@ -265,12 +265,13 @@ compiles to a single CPU comparison instruction matching the BCL's internal sort helpers; for custom types the JIT can devirtualize `CompareTo` on value types, keeping the call nearly as cheap. -For `byte` and `sbyte`, the generator additionally emits SIMD vectorization -when available — AVX2 on x86 and AdvSimd (NEON) on ARM64. 
For sizes up to 32, -all elements fit in a single vector register (or two on ARM64), allowing each -network step to execute as a vectorized shuffle + min/max + blend operation. -On ARM64, SIMD extends up to 64 elements using up to four `Vector128` -registers with single-group TBL4 lookups. +For `byte` and `sbyte`, the generator emits SIMD vectorization when available. +On x86 with AVX-512 VBMI, all elements (sizes 8-64) fit in a single +`Vector512` register using `PermuteVar64x8` shuffles — the fastest path. +On x86 without VBMI, AVX2 is used as a fallback for sizes 8-32 with +`Vector256` and `Avx2.Shuffle` (vpshufb). On ARM64 AdvSimd (NEON), up to +four `Vector128` registers with single-group TBL4 lookups handle sizes up +to 64 elements. For `int` and `uint`, AVX2 SIMD is emitted on x86 with four `Vector256` registers (8 elements each). Cross-vector shuffles use `PermuteVar8x32` with @@ -411,14 +412,21 @@ inlines the call for value types, keeping overhead low: ### x86 AVX-512F (AMD EPYC 9V74, GitHub Actions) -On CPUs with AVX-512F (e.g., AMD EPYC, Intel Ice Lake+), additional SIMD paths -are available. 
For `byte` and `sbyte`, the AVX2 path uses optimized direct -`Avx2.Shuffle` intrinsics (vpshufb) avoiding the expensive cross-lane emulation: +On CPUs with AVX-512 VBMI (e.g., AMD Zen 4+, Intel Ice Lake+), `byte` and `sbyte` +use a single `Vector512` with `PermuteVar64x8` shuffles for all sizes: -| Type | ArraySort (27) | GeneratedSort (27) | Speedup | -|---|---|---|---| -| byte | 1,591 ns | 61 ns | **26x** | -| sbyte | 1,714 ns | 68 ns | **25x** | +| Type | Size | ArraySort | GeneratedSort | Speedup | +|---|---|---|---|---| +| byte | 27 | 1,250 ns | 53 ns | **24x** | +| byte | 28 | 1,415 ns | 54 ns | **26x** | +| byte | 32 | 1,516 ns | 54 ns | **28x** | +| byte | 34 | 1,759 ns | 64 ns | **27x** | +| sbyte | 27 | 1,355 ns | 57 ns | **24x** | +| sbyte | 28 | 1,495 ns | 58 ns | **26x** | +| sbyte | 32 | 1,598 ns | 58 ns | **28x** | +| sbyte | 38 | 2,160 ns | 68 ns | **32x** | + +> On CPUs without AVX-512 VBMI, byte/sbyte sizes 8-32 fall back to AVX2. For `int`, `uint`, and `float`, AVX-512F uses two `Vector512` registers with `PermuteVar16x32x2` cross-vector shuffles: diff --git a/SortingNetworks.Generators/SimdX86Emitter.cs b/SortingNetworks.Generators/SimdX86Emitter.cs index 9f71fbd..05b7c30 100644 --- a/SortingNetworks.Generators/SimdX86Emitter.cs +++ b/SortingNetworks.Generators/SimdX86Emitter.cs @@ -37,17 +37,15 @@ internal static string GetGuardCondition(SpecialType specialType) /// /// Returns the SIMD guard condition string for the given element type and size. - /// For byte/sbyte sizes > 32, AVX-512 VBMI is required instead of AVX2. + /// For byte/sbyte, AVX-512 VBMI is the primary path for all sizes. + /// AVX2 is available as a fallback for sizes ≤ 32 via <see cref="EmitAvx2Fallback"/>. /// internal static string GetGuardCondition(SpecialType specialType, int size) { int elemBytes = ElementSize(specialType); switch (elemBytes) { - case 1: - return size > 32 - ? 
"System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported" - : "System.Runtime.Intrinsics.X86.Avx2.IsSupported"; + case 1: return "System.Runtime.Intrinsics.X86.Avx512Vbmi.IsSupported"; case 2: return "System.Runtime.Intrinsics.X86.Avx512BW.IsSupported"; case 4: return "System.Runtime.Intrinsics.X86.Avx2.IsSupported"; case 8: return "System.Runtime.Intrinsics.X86.Avx512F.IsSupported"; @@ -60,6 +58,9 @@ internal static string GetGuardCondition(SpecialType specialType, int size) /// internal static bool CanEmitAvx2Fallback(SpecialType specialType, int size) { + // 8-bit types: AVX2 fallback for sizes 8-32 (single Vector256) + if (ElementSize(specialType) == 1 && size >= 8 && size <= 32) + return true; // 16-bit types: AVX2 fallback for sizes 8-16 (single Vector256) if (ElementSize(specialType) == 2 && size >= 8 && size <= 16) return true; @@ -92,6 +93,8 @@ internal static string GetAvx2FallbackGuardCondition() internal static (string MethodSource, string DispatchCode) EmitAvx2Fallback(int size, string typeName, SpecialType specialType, List> steps) { if (!CanEmitAvx2Fallback(specialType, size)) return ("", ""); + if (ElementSize(specialType) == 1) + return EmitByteAvx2(size, typeName, specialType, steps); if (specialType == SpecialType.System_Double) return EmitDoubleAvx2(size, steps); return EmitShortAvx2(size, typeName, specialType, steps); @@ -107,7 +110,7 @@ internal static (string MethodSource, string DispatchCode) Emit(int size, string switch (elemBytes) { - case 1: return EmitByte(size, typeName, specialType, steps); + case 1: return EmitByteAvx512Vbmi(size, typeName, specialType, steps); case 2: return EmitShort(size, typeName, specialType, steps); case 4: return EmitInt32(size, typeName, specialType, steps); case 8: return EmitInt64(size, typeName, specialType, steps); @@ -146,18 +149,18 @@ internal static (string MethodSource, string DispatchCode) Emit(int size, string return steps; } - // --- Byte (8-bit) types: single Vector256 for sizes ≤ 32, Vector512 for 
33-64 --- + // --- Byte (8-bit) types: AVX2 fallback for sizes 8-32 using Vector256 --- - private static (string, string) EmitByte(int size, string typeName, SpecialType specialType, List> steps) + private static (string, string) EmitByteAvx2(int size, string typeName, SpecialType specialType, List> steps) { - if (size > 32) return EmitByteAvx512Vbmi(size, typeName, specialType, steps); + if (size > 32) return ("", ""); var sb = new StringBuilder(); string suffix = $"_{typeName}"; // Method signature sb.AppendLine($" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveOptimization)]"); - sb.AppendLine($" private static void SortSimd{size}{suffix}(System.Span<{typeName}> span)"); + sb.AppendLine($" private static void SortSimdAvx2_{size}{suffix}(System.Span<{typeName}> span)"); sb.AppendLine(" {"); sb.AppendLine($" ref byte first = ref System.Runtime.CompilerServices.Unsafe.As<{typeName}, byte>(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span));"); @@ -266,14 +269,14 @@ private static (string, string) EmitByte(int size, string typeName, SpecialType string dispatchSb = $" if (System.Runtime.Intrinsics.X86.Avx2.IsSupported)\n" + $" {{\n" + - $" SortSimd{size}{suffix}(span);\n" + + $" SortSimdAvx2_{size}{suffix}(span);\n" + $" return;\n" + $" }}\n"; return (sb.ToString(), dispatchSb); } - // --- Byte (8-bit) types: single Vector512 for sizes 33-64 using AVX-512 VBMI --- + // --- Byte (8-bit) types: single Vector512 for all sizes 8-64 using AVX-512 VBMI --- private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, SpecialType specialType, List> steps) { @@ -292,7 +295,7 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp { sb.AppendLine(" var vec = System.Runtime.CompilerServices.Unsafe.ReadUnaligned>(ref first);"); } - else + else if (size > 32) { // Partial: read lower 256 bits directly, upper 256 bits with overlap int hiReadOffset = size - 32; @@ 
-319,6 +322,47 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp sb.AppendLine($" System.Runtime.Intrinsics.Vector512.Create(lo, hiRaw),"); sb.AppendLine($" {FmtVec512Byte(loadPerm)});"); } + else if (size == 32) + { + // Full Vector256: zero-extend to Vector512 + sb.AppendLine(" var vec = System.Runtime.Intrinsics.Vector512.Create("); + sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.ReadUnaligned>(ref first),"); + sb.AppendLine(" System.Runtime.Intrinsics.Vector256.Zero);"); + } + else if (size > 16) + { + // 17-31: overlapping Vector128 reads, combine to Vector256, zero-extend to Vector512 + int hiReadOffset = size - 16; + sb.AppendLine(" var lo128 = System.Runtime.CompilerServices.Unsafe.ReadUnaligned>(ref first);"); + sb.AppendLine($" var hi128 = System.Runtime.CompilerServices.Unsafe.ReadUnaligned>(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));"); + // Build permutation to rearrange into proper element order. + // lo128[j] = element j for j in [0..15] + // hi128[j] = element (hiReadOffset + j) for j in [0..15] + // Combined: lo128 at [0..15], hi128 at [16..31], zeros at [32..63] + // For j >= 16: element j is at hi128[j - hiReadOffset] → source index 16 + (j - hiReadOffset) + byte[] loadPerm = new byte[64]; + for (int i = 0; i < 64; i++) + { + if (i < 16) + loadPerm[i] = (byte)i; + else if (i < size) + loadPerm[i] = (byte)(i + 32 - size); + else + loadPerm[i] = 0; + } + sb.AppendLine($" var vec = System.Runtime.Intrinsics.X86.Avx512Vbmi.PermuteVar64x8("); + sb.AppendLine($" System.Runtime.Intrinsics.Vector512.Create(System.Runtime.Intrinsics.Vector256.Create(lo128, hi128), System.Runtime.Intrinsics.Vector256.Zero),"); + sb.AppendLine($" {FmtVec512Byte(loadPerm)});"); + } + else + { + // 8-16: single Vector128 read, zero-extend to Vector512 + sb.AppendLine(" var vec = System.Runtime.Intrinsics.Vector512.Create("); + sb.AppendLine(" System.Runtime.Intrinsics.Vector256.Create("); + 
sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.ReadUnaligned>(ref first),"); + sb.AppendLine(" System.Runtime.Intrinsics.Vector128.Zero),"); + sb.AppendLine(" System.Runtime.Intrinsics.Vector256.Zero);"); + } // Emit each step for (int si = 0; si < steps.Count; si++) @@ -352,7 +396,7 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp { sb.AppendLine(" vec.StoreUnsafe(ref first);"); } - else + else if (size > 32) { int hiReadOffset = size - 32; // Build permutation to extract upper portion for overlapping store @@ -368,6 +412,29 @@ private static (string, string) EmitByteAvx512Vbmi(int size, string typeName, Sp sb.AppendLine($" .GetUpper().StoreUnsafe(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));"); sb.AppendLine(" vec.GetLower().StoreUnsafe(ref first);"); } + else if (size == 32) + { + // Store lower Vector256 + sb.AppendLine(" vec.GetLower().StoreUnsafe(ref first);"); + } + else if (size > 16) + { + // 17-31: overlapping Vector128 stores + int hiReadOffset = size - 16; + // Extract elements [size-16, size) into lower Vector128 via permutation + byte[] hiStorePerm = new byte[64]; + for (int i = 0; i < 64; i++) hiStorePerm[i] = 0; + for (int i = 0; i < 16; i++) + hiStorePerm[i] = (byte)(size - 16 + i); + sb.AppendLine($" System.Runtime.Intrinsics.X86.Avx512Vbmi.PermuteVar64x8(vec, {FmtVec512Byte(hiStorePerm)})"); + sb.AppendLine($" .GetLower().GetLower().StoreUnsafe(ref System.Runtime.CompilerServices.Unsafe.Add(ref first, {hiReadOffset}));"); + sb.AppendLine(" vec.GetLower().GetLower().StoreUnsafe(ref first);"); + } + else + { + // 8-16: store lower Vector128 + sb.AppendLine(" vec.GetLower().GetLower().StoreUnsafe(ref first);"); + } sb.AppendLine(" }"); @@ -1474,7 +1541,7 @@ private static int MaxElements(int elemBytes) // Max elements we support with SIMD switch (elemBytes) { - case 1: return 64; // Vector256 (≤32) or Vector512 with VBMI (33-64) + case 1: return 64; // Vector512 with VBMI (all 
sizes 8-64), AVX2 fallback for ≤32 case 2: return 64; // 2x Vector512 with PermuteVar32x16x2 case 4: return 64; // 8x Vector256 case 8: return 32; // 4x Vector512 (double uses AVX2 fallback for >32) diff --git a/SortingNetworks.Tests/GeneratedSortTests.cs b/SortingNetworks.Tests/GeneratedSortTests.cs index e542a8b..5693935 100644 --- a/SortingNetworks.Tests/GeneratedSortTests.cs +++ b/SortingNetworks.Tests/GeneratedSortTests.cs @@ -78,12 +78,18 @@ private static void StressSort(int size, Func generator, Action StressSort(28, rng => (byte)rng.Next(0, 256), a => GeneratedSorters.Sort(a.AsSpan())); + [Fact] + public void Sort_32Elements_Byte() => StressSort(32, rng => (byte)rng.Next(0, 256), a => GeneratedSorters.Sort(a.AsSpan())); + [Fact] public void Sort_27Elements_SByte() => StressSort(27, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan())); [Fact] public void Sort_28Elements_SByte() => StressSort(28, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan())); + [Fact] + public void Sort_32Elements_SByte() => StressSort(32, rng => (sbyte)rng.Next(-128, 128), a => GeneratedSorters.Sort(a.AsSpan())); + [Fact] public void Sort_27Elements_Short() => StressSort(27, rng => (short)rng.Next(-1000, 1000), a => GeneratedSorters.Sort(a.AsSpan())); diff --git a/SortingNetworks.Tests/GeneratedSorters.cs b/SortingNetworks.Tests/GeneratedSorters.cs index d638410..743e6bb 100644 --- a/SortingNetworks.Tests/GeneratedSorters.cs +++ b/SortingNetworks.Tests/GeneratedSorters.cs @@ -6,10 +6,12 @@ namespace SortingNetworks.Tests; [SortingNetwork(16, typeof(byte))] [SortingNetwork(27, typeof(byte))] [SortingNetwork(28, typeof(byte))] +[SortingNetwork(32, typeof(byte))] [SortingNetwork(48, typeof(byte))] [SortingNetwork(64, typeof(byte))] [SortingNetwork(27, typeof(sbyte))] [SortingNetwork(28, typeof(sbyte))] +[SortingNetwork(32, typeof(sbyte))] [SortingNetwork(48, typeof(sbyte))] [SortingNetwork(64, typeof(sbyte))] [SortingNetwork(8, typeof(short))] 
diff --git a/SortingNetworks.Tests/GeneratorTests.cs b/SortingNetworks.Tests/GeneratorTests.cs index 8fd4e8c..87877bf 100644 --- a/SortingNetworks.Tests/GeneratorTests.cs +++ b/SortingNetworks.Tests/GeneratorTests.cs @@ -220,6 +220,7 @@ public partial class MySorter { } [InlineData(64, "float")] [InlineData(8, "byte")] [InlineData(16, "byte")] + [InlineData(32, "byte")] [InlineData(48, "byte")] [InlineData(64, "byte")] [InlineData(8, "ushort")] @@ -473,6 +474,43 @@ public partial class MySorter {{ }} Assert.Contains("Avx2.IsSupported", generatedSource); } + [Theory] + [InlineData(8, "byte")] + [InlineData(16, "byte")] + [InlineData(28, "byte")] + [InlineData(32, "byte")] + [InlineData(8, "sbyte")] + [InlineData(16, "sbyte")] + [InlineData(28, "sbyte")] + [InlineData(32, "sbyte")] + public void SimdCode_8Bit_HasAvx2Fallback(int size, string typeName) + { + var source = $@" +using SortingNetworks; + +[SortingNetwork({size}, typeof({typeName}))] +public partial class MySorter {{ }} +"; + var compilation = SourceGeneratorDriver.CreateCompilation(source); + var (result, updatedCompilation) = SourceGeneratorDriver.RunGeneratorWithCompilation(compilation); + + var errors = result.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToArray(); + Assert.Empty(errors); + + var compilationErrors = SourceGeneratorDriver.GetErrors(updatedCompilation); + Assert.Empty(compilationErrors); + + // Verify both AVX-512 VBMI and AVX2 methods were generated + var generatedSource = result.GeneratedTrees + .Select(t => t.GetText().ToString()) + .FirstOrDefault(s => s.Contains($"SortSimd{size}_{typeName}") && s.Contains($"SortSimdAvx2_{size}_{typeName}")); + Assert.NotNull(generatedSource); + + // Verify dispatch cascades from AVX-512 VBMI to AVX2 + Assert.Contains("Avx512Vbmi.IsSupported", generatedSource); + Assert.Contains("Avx2.IsSupported", generatedSource); + } + [Fact] public void ComparerOverload_GeneratesCode() {