
Commit c19bec9

Improve: FP8 GEMM throughput on Skylake/Haswell + Granite Rapids E5M2 kernel
E5M2 cast (`cast/{skylake,haswell}.h`):
- Rewrite `nk_e5m2x{16,8}_to_f32x{16,8}_*` as 3-op (cvtepu8 + slli 8 + cvtph_ps). E5M2 shares F16 bias, so `byte << 8` is the matching F16 encoding.

Pairwise dot (`dot/{skylake,haswell}.h`):
- E5M2 widened to 64/16-lane multi-chain inline unpack: Skylake 3.17-3.94×, Haswell 2.07-5.05× across D=100..768.
- E4M3 unchanged (16-lane single-chain was already at the cast cost limit).

Stateful GEMM (`dot/`, `dots/{skylake,haswell}.h`):
- E5M2: byte-pack + new dtype-specific `nk_dot_e5m2x{64,32}_update_*` with two independent FMA chains folding into the single state accumulator. Skylake dots/angulars/euclideans 1.42-2.52×, Haswell up to 3.17×.
- E4M3 Skylake: asymmetric F16-pack (A streamed as F32, B pre-cast and stored as F16, widened to F32 on load). 50% memory savings for the packed B; compute ~0.89-1.16× of baseline (accepted tradeoff).
- E4M3 Haswell: byte-pack at depth=32, neutral 0.95-1.03×.

New Granite Rapids E5M2 GEMM (`dots/graniteamx.h`, `spatials/graniteamx.h`):
- Pack E5M2 → F16 via `byte << 8`, run TDPFP16PS over F16 tiles to F32. Beats Sapphire AMX BF16 path on E5M2 inputs (better intermediate precision, same throughput). Wired through `dispatch_e5m2.c` ahead of Sapphire AMX.

Sapphire AMX (`dots/sapphireamx.h`): keeps icelake LUT cast helpers (empirical eval showed Genoa-Giesen path regressed; revert preserved).

Spatials (`spatials/skylake.h`): E4M3 normalize_packed packed_value_type follows dots packed_value_type (f16) for matching norm offset.

Bonus: collapse split string literals across JS/Python error messages (`javascript/numkong.c`, `python/{each,matrix,tensor}.c`).

ULP: dots/spatials max_ulp ≤ 1 vs scalar reference (threshold 32).
1 parent 679f55f commit c19bec9

22 files changed

Lines changed: 921 additions & 135 deletions

bench/bench_cross_amx.cpp

Lines changed: 12 additions & 0 deletions
@@ -58,5 +58,17 @@ void bench_cross_amx() {
     run_euclideans_packed<f16_k>("euclideans_packed_f16_graniteamx", nk_dots_packed_size_f16_graniteamx,
                                  nk_dots_pack_f16_graniteamx, nk_euclideans_packed_f16_graniteamx);
     run_euclideans_symmetric<f16_k>("euclideans_symmetric_f16_graniteamx", nk_euclideans_symmetric_f16_graniteamx);
+
+    run_dots_packed<e5m2_k>("dots_packed_e5m2_graniteamx", nk_dots_packed_size_e5m2_graniteamx,
+                            nk_dots_pack_e5m2_graniteamx, nk_dots_packed_e5m2_graniteamx);
+    run_dots_symmetric<e5m2_k>("dots_symmetric_e5m2_graniteamx", nk_dots_symmetric_e5m2_graniteamx);
+
+    run_angulars_packed<e5m2_k>("angulars_packed_e5m2_graniteamx", nk_dots_packed_size_e5m2_graniteamx,
+                                nk_dots_pack_e5m2_graniteamx, nk_angulars_packed_e5m2_graniteamx);
+    run_angulars_symmetric<e5m2_k>("angulars_symmetric_e5m2_graniteamx", nk_angulars_symmetric_e5m2_graniteamx);
+
+    run_euclideans_packed<e5m2_k>("euclideans_packed_e5m2_graniteamx", nk_dots_packed_size_e5m2_graniteamx,
+                                  nk_dots_pack_e5m2_graniteamx, nk_euclideans_packed_e5m2_graniteamx);
+    run_euclideans_symmetric<e5m2_k>("euclideans_symmetric_e5m2_graniteamx", nk_euclideans_symmetric_e5m2_graniteamx);
 #endif
 }

c/dispatch_e5m2.c

Lines changed: 23 additions & 0 deletions
@@ -113,6 +113,29 @@ void nk_dispatch_e5m2_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
     default: break;
     }
 #endif
+#if NK_TARGET_GRANITEAMX
+    if (v & nk_cap_graniteamx_k) switch (k) {
+        case nk_kernel_dots_packed_size_k:
+            *m = (m_t)&nk_dots_packed_size_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+            return;
+        case nk_kernel_dots_pack_k: *m = (m_t)&nk_dots_pack_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+        case nk_kernel_dots_packed_k: *m = (m_t)&nk_dots_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+        case nk_kernel_angulars_packed_k:
+            *m = (m_t)&nk_angulars_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+            return;
+        case nk_kernel_euclideans_packed_k:
+            *m = (m_t)&nk_euclideans_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+            return;
+        case nk_kernel_dots_symmetric_k: *m = (m_t)&nk_dots_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+        case nk_kernel_angulars_symmetric_k:
+            *m = (m_t)&nk_angulars_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+            return;
+        case nk_kernel_euclideans_symmetric_k:
+            *m = (m_t)&nk_euclideans_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+            return;
+        default: break;
+        }
+#endif
 #if NK_TARGET_SAPPHIREAMX
     if (v & nk_cap_sapphireamx_k) switch (k) {
         case nk_kernel_dots_packed_size_k:

include/numkong/cast/README.md

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ NEON backend uses `vreinterpretq_u16_u8` + `vzip` for zero-extension; Haswell us
 `nk_f16_to_f32_haswell`, `nk_f32_to_f16_haswell` use the F16C extension instructions `VCVTPH2PS` / `VCVTPS2PH` — single-instruction conversion of 8 elements with correct denormal handling, NaN propagation, and RNE rounding.
 The serial fallback (`nk_f16_to_f32_serial`) must handle denormals via explicit exponent/mantissa extraction and conditional re-normalization — ~15 integer ops per element vs 1 instruction with F16C.
 AVX-512 (`nk_cast_skylake`) doubles throughput to 16 elements per instruction.
+F16C also unlocks a cheaper FP8 → F32 path that bypasses i32-lane bit math: `nk_e5m2x16_to_f32x16_skylake_` and `nk_e5m2x8_to_f32x8_haswell_` widen u8 → u16 and left-shift by 8 (E5M2 shares F16's bias 15, so the result is a bit-exact F16 encoding of every input including subnormals and NaN), then feed `VCVTPH2PS` — three ops total.
+E4M3 can't use a plain shift (bias 7 vs 15), but the Giesen-style fake-F16 `((byte & 0x7F) << 7) | ((byte & 0x80) << 8)` gives an F16 whose value differs from the E4M3 magnitude by exactly 2⁸; `nk_e4m3x16_to_f32x16_skylake_` and `nk_e4m3x8_to_f32x8_haswell_` widen through `VCVTPH2PS`, multiply by 256 in F32 to correct, and blend in F32 NaN for the lone `|byte|==0x7F` encoding.
+For E4M3 GEMM specifically, `nk_e4m3x16_to_f16x16_skylake_` produces TRUE F16 (bias-corrected, with a small subnormal LUT and NaN blend) so the packed buffer stores 2 bytes/element instead of 4 — the inner loop reads F16 and widens to F32 once per B-load, trading ~10% compute for 50% pack memory.
 
 ## Performance
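The two FP8 → F32 identities above are easy to verify in scalar code. The sketch below is illustrative only (none of these names exist in the library): `f16_bits_to_f32` mimics what `VCVTPH2PS` does for one lane, `e5m2_to_f32` applies the `byte << 8` identity, and `e4m3_to_f32` applies the fake-F16 cast with the ×256 correction and a separate NaN patch.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Scalar F16 -> F32 decoder (subnormals, Inf, NaN), standing in for one VCVTPH2PS lane. */
static float f16_bits_to_f32(uint16_t h) {
    unsigned sign = h >> 15, exp = (h >> 10) & 0x1F, mant = h & 0x3FF;
    float magnitude;
    if (exp == 0) magnitude = ldexpf((float)mant, -24);                   /* subnormal: mant * 2^-24 */
    else if (exp == 31) magnitude = mant ? nanf("") : INFINITY;           /* Inf / NaN */
    else magnitude = ldexpf(1.0f + (float)mant / 1024.0f, (int)exp - 15); /* normal */
    return sign ? -magnitude : magnitude;
}

/* E5M2 shares F16's bias (15): byte << 8 is already the matching F16 bit pattern. */
static float e5m2_to_f32(uint8_t byte) { return f16_bits_to_f32((uint16_t)(byte << 8)); }

/* E4M3 (bias 7): the fake-F16 pattern decodes to the true magnitude scaled by 2^-8,
 * so multiply by 256; the lone NaN encoding (|byte| == 0x7F) is patched separately. */
static float e4m3_to_f32(uint8_t byte) {
    if ((byte & 0x7F) == 0x7F) return nanf("");
    uint16_t fake_f16 = (uint16_t)(((byte & 0x7F) << 7) | ((byte & 0x80) << 8));
    return 256.0f * f16_bits_to_f32(fake_f16);
}

int main(void) {
    printf("e5m2 0x3C -> %g\n", e5m2_to_f32(0x3C)); /* 1 */
    printf("e5m2 0x01 -> %g\n", e5m2_to_f32(0x01)); /* 1.52588e-05 (2^-16, smallest subnormal) */
    printf("e4m3 0x38 -> %g\n", e4m3_to_f32(0x38)); /* 1 */
    printf("e4m3 0x01 -> %g\n", e4m3_to_f32(0x01)); /* 0.00195312 (2^-9, smallest subnormal) */
    return 0;
}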

include/numkong/cast/haswell.h

Lines changed: 7 additions & 24 deletions
@@ -194,31 +194,14 @@ NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
     return _mm256_mul_ps(fake_f32x8, _mm256_set1_ps(256.0f));
 }
 
-/** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
- *  E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
- *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
+/** @brief Convert 8x e5m2 → 8x f32 via free-shift widen (AVX2 + F16C).
+ *  E5M2 shares F16's exponent bias (15): `(byte << 8)` is the matching F16 bit
+ *  pattern for every E5M2 value (normals, subnormals, zero, ±Inf, NaN — all
+ *  bit-exact). Widen u8 → u16, shift, then VCVTPH2PS to F32. Three ops total. */
 NK_INTERNAL __m256 nk_e5m2x8_to_f32x8_haswell_(__m128i e5m2_i8x8) {
-    __m256i e5m2_i32x8 = _mm256_cvtepu8_epi32(e5m2_i8x8);
-
-    // Extract fields
-    __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e5m2_i32x8, 2), _mm256_set1_epi32(0x1F));
-    __m256i mant_i32x8 = _mm256_and_si256(e5m2_i32x8, _mm256_set1_epi32(0x03));
-
-    // Build F32 sign bit
-    __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e5m2_i32x8, 7), 31);
-
-    // Normal path: sign | ((exp+112)<<23) | (mant<<21)
-    __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(112)), 23);
-    __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
-    __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
-
-    // Subnormal path: value = mantissa / 65536.0f, then apply sign
-    __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 65536.0f));
-    __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
-
-    // Blend: if exp==0, use subnormal result; otherwise use normal bits
-    __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
-    return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
+    __m128i e5m2_u16x8 = _mm_cvtepu8_epi16(e5m2_i8x8);
+    __m128i f16_bits_u16x8 = _mm_slli_epi16(e5m2_u16x8, 8);
+    return _mm256_cvtph_ps(f16_bits_u16x8);
 }
 
 /** @brief Convert 8x f32 → 8x e4m3 via bit manipulation (AVX2).

include/numkong/cast/skylake.h

Lines changed: 45 additions & 20 deletions
@@ -198,27 +198,40 @@ NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
     return _mm512_mul_ps(fake_f32x16, _mm512_set1_ps(256.0f));
 }
 
-/** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
- *  E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mantissa<<21.
- *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
+/** @brief Convert 16x e4m3 → 16x f16 via arithmetic + 8-entry subnormal LUT (AVX-512BW + AVX-512VL).
+ *  E4M3: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
+ *  Normal (exp != 0): F16 = ((lower7 << 7) + 0x2000) | (sign << 8) — bias delta 8 added at the
+ *  exp-position (8 << 10 = 0x2000) after placing magnitude bits at F16 positions 13..7.
+ *  Subnormal (exp == 0): looked up from 8-entry F16 LUT — values 0, 1/512, 2/512, …, 7/512 encoded as
+ *  F16 normals (the smallest E4M3 subnormal 1/512 = 2⁻⁹ is well within F16 normal range).
+ *  NaN (|byte| == 0x7F): blended in as F16 quiet NaN with original sign. */
+NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_skylake_(__m128i e4m3_u8x16) {
+    __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3_u8x16);
+    __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
+    __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
+    __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
+    __m256i subn_lut_i16x16 = _mm256_set_epi16( //
+        0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00, 0x1800, 0x0000, 0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00,
+        0x1800, 0x0000);
+    __m256i mant_idx_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x07));
+    __m256i subn_abs_i16x16 = _mm256_permutexvar_epi16(mant_idx_i16x16, subn_lut_i16x16);
+    __mmask16 is_subnormal = _mm256_testn_epi16_mask(e4m3_i16x16, _mm256_set1_epi16(0x78));
+    __m256i abs_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_abs_i16x16, subn_abs_i16x16);
+    __m256i shifted_sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
+    __m256i result_i16x16 = _mm256_or_si256(abs_i16x16, shifted_sign_i16x16);
+    __mmask16 is_nan = _mm256_cmpeq_epi16_mask(lower7_i16x16, _mm256_set1_epi16(0x7F));
+    __m256i nan_i16x16 = _mm256_or_si256(shifted_sign_i16x16, _mm256_set1_epi16(0x7E00));
+    return _mm256_mask_blend_epi16(is_nan, result_i16x16, nan_i16x16);
+}
+
+/** @brief Convert 16x e5m2 → 16x f32 via free-shift widen (AVX-512 + F16C).
+ *  E5M2 shares F16's exponent bias (15): `(byte << 8)` is the matching F16 bit
+ *  pattern for every E5M2 value (normals, subnormals, zero, ±Inf, NaN — all
+ *  bit-exact). Widen u8 → u16, shift, then VCVTPH2PS to F32. Three ops total. */
 NK_INTERNAL __m512 nk_e5m2x16_to_f32x16_skylake_(__m128i e5m2_i8x16) {
-    __m512i e5m2_i32x16 = _mm512_cvtepu8_epi32(e5m2_i8x16);
-
-    // Extract fields
-    __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e5m2_i32x16, 2), _mm512_set1_epi32(0x1F));
-    __m512i mantissa_i32x16 = _mm512_and_si512(e5m2_i32x16, _mm512_set1_epi32(0x03));
-    __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e5m2_i32x16, 7), 31);
-
-    // Normal path: sign | ((exp+112)<<23) | (mantissa<<21)
-    __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(112)), 23);
-    __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
-    __m512 result_f32x16 = _mm512_castsi512_ps(
-        _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
-
-    // Subnormal fix: for exp==0 lanes, replace with (mantissa / 65536) | sign using masked OR
-    __mmask16 is_subnormal = _mm512_testn_epi32_mask(e5m2_i32x16, _mm512_set1_epi32(0x7C));
-    __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 65536.0f));
-    return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
+    __m256i e5m2_u16x16 = _mm256_cvtepu8_epi16(e5m2_i8x16);
+    __m256i f16_bits_u16x16 = _mm256_slli_epi16(e5m2_u16x16, 8);
+    return _mm512_cvtph_ps(f16_bits_u16x16);
 }
 
 /** @brief Convert 16x e2m3 → 16x f32 via bit manipulation (AVX-512).
@@ -650,6 +663,18 @@ NK_INTERNAL void nk_partial_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_
     dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(e4m3_partial.xmm);
 }
 
+/** @brief Load 16 e4m3 values and convert to 16 f16 (Skylake AVX-512BW). */
+NK_INTERNAL void nk_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst) {
+    dst->ymm = nk_e4m3x16_to_f16x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+}
+
+/** @brief Partial load of up to 16 e4m3 values with conversion to f16 (Skylake AVX-512BW). */
+NK_INTERNAL void nk_partial_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
+    nk_b128_vec_t e4m3_partial;
+    nk_partial_load_b8x16_skylake_(src, &e4m3_partial, n);
+    dst->ymm = nk_e4m3x16_to_f16x16_skylake_(e4m3_partial.xmm);
+}
+
 /** @brief Load 16 e5m2 values and convert to 16 f32 (Skylake AVX-512). */
 NK_INTERNAL void nk_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
     dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));

include/numkong/dot/README.md

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ This processes 64 E4M3 bytes per iteration in u8, doubling the element density o
 
 `nk_dot_e5m2_genoa` converts FP8 values to BF16, then accumulates via `VDPBF16PS`, reusing Genoa's BF16 dot-product instruction for FP8 types.
 Each `VDPBF16PS` fuses two BF16 multiply-adds per 32-bit lane at 6-cycle throughput.
+On Skylake-X–class CPUs without BF16 dot-product hardware, `nk_dot_e4m3_skylake` / `nk_dot_e5m2_skylake` (and their Haswell twins `nk_dot_e4m3_haswell` / `nk_dot_e5m2_haswell`) instead route through the Giesen-style FP8 → F16 fake-bit-pattern cast, widen via `VCVTPH2PS`, and accumulate in F32 with two independent FMA chains reducing into a single register — avoiding the 3-chain scheduler-stall of the BF16 algebraic form on kernels without native BF16 FMA.
 `nk_dot_bf16c_genoa` uses the same instruction for complex BF16, preparing operands with `VPSHUFB` for lane swapping and `VPXORD` with `0x80000000` for sign flips before feeding into `VDPBF16PS`.
 
 ### Deferred Sign-Flip in Complex Dot Products
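As an illustration of the two-FMA-chain accumulation described in the added line above, here is a minimal AVX2 + F16C + FMA sketch of an E5M2 dot product (compile with -mavx2 -mfma -mf16c). It assumes n is a multiple of 16, omits the masked tail, and uses hypothetical names; it is not the library's `nk_dot_e5m2_haswell` itself.

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

/* Low 8 E5M2 bytes of the input -> 8x F32, via the byte << 8 == F16 identity. */
static inline __m256 e5m2x8_to_f32x8(__m128i e5m2_bytes) {
    __m128i f16_bits = _mm_slli_epi16(_mm_cvtepu8_epi16(e5m2_bytes), 8);
    return _mm256_cvtph_ps(f16_bits);
}

/* Two independent FMA accumulators hide FMA latency; they fold into one sum at the end. */
static float dot_e5m2_two_chain(uint8_t const *a, uint8_t const *b, size_t n) {
    __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps();
    for (size_t i = 0; i + 16 <= n; i += 16) {
        __m128i a16 = _mm_loadu_si128((__m128i const *)(a + i));
        __m128i b16 = _mm_loadu_si128((__m128i const *)(b + i));
        acc0 = _mm256_fmadd_ps(e5m2x8_to_f32x8(a16), e5m2x8_to_f32x8(b16), acc0);
        acc1 = _mm256_fmadd_ps(e5m2x8_to_f32x8(_mm_srli_si128(a16, 8)),
                               e5m2x8_to_f32x8(_mm_srli_si128(b16, 8)), acc1);
    }
    __m256 acc = _mm256_add_ps(acc0, acc1); /* fold the two chains */
    __m128 sum4 = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
    __m128 sum2 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
    __m128 sum1 = _mm_add_ss(sum2, _mm_shuffle_ps(sum2, sum2, 0x55));
    return _mm_cvtss_f32(sum1);
}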
