Add CPU kernel skills #614
Merged
14 commits
64fa4de add cpu kernel skills (jiqing-feng)
467a0f9 update (jiqing-feng)
2078b8c update (jiqing-feng)
80c064b update (jiqing-feng)
27fe4e5 Merge branch 'main' into cpu_skill (jiqing-feng)
7696fb2 update (jiqing-feng)
aa6a03d update (jiqing-feng)
9ab3cb7 Merge branch 'main' into cpu_skill (jiqing-feng)
0b407ad update (jiqing-feng)
ba4f61f Merge branch 'main' into cpu_skill (jiqing-feng)
1fa4994 Merge branch 'main' into cpu_skill (jiqing-feng)
2a84e74 update (jiqing-feng)
7cb6da9 fix style (jiqing-feng)
1d4c374 add examples (jiqing-feng)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,11 +5,23 @@ Supported skills include: | |
| - `cuda-kernels` (default) | ||
| - `rocm-kernels` | ||
| - `xpu-kernels` | ||
| - `cpu-kernels` | ||
|
|
||
| Skill files are downloaded from the `huggingface/kernels` directory in this [repository](https://github.com/huggingface/kernels/tree/main/kernel-builder/skills). | ||
|
|
||
| Skills instruct agents how to deal with hardware-specific optimizations, integrate with libraries like diffusers and transformers, and benchmark kernel performance in consistent ways. | ||
|
|
||
| > [!TIP] | ||
| > **When are CPU kernels actually helpful?** Two main cases: | ||
| > - **Better performance on Intel Xeon** — custom AVX2/AVX512 kernels (and AMX via brgemm for quantized GEMM) outperform generic PyTorch ops for element-wise and quantized workloads, especially in CPU-only or latency-sensitive serving. | ||
| > - **Enabling functionality that otherwise can't run** — some kernels are a hard requirement, e.g. `megablocks` MoE on CPU, where without the kernel you simply cannot run MXFP4. | ||
|
Comment on lines +14 to +17
Member: Nice! Can you provide some example kernels that you have built for CPU?
Contributor (Author): Done.
||
|
|
||
| Example CPU kernels built with this skill (available on the Hub under [`kernels-community`](https://huggingface.co/kernels-community)): | ||
|
|
||
| - [`kernels-community/megablocks`](https://huggingface.co/kernels-community/megablocks) — MoE kernels with a CPU backend that enable running MXFP4 MoE models on CPU. | ||
| - [`kernels-community/quantization-gptq`](https://huggingface.co/kernels-community/quantization-gptq) — INT4 quantized GEMM using AVX512. | ||
| - [`kernels-community/rmsnorm`](https://huggingface.co/kernels-community/rmsnorm) — RMSNorm with AVX2/AVX512 element-wise paths. | ||
|
|
||
| Examples: | ||
|
|
||
| ```bash | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| SKILL.md | ||
| manifest.txt | ||
| scripts/config.yaml | ||
| scripts/config.py | ||
| scripts/analyze_op.py | ||
| scripts/validate_cpu_kernel.py | ||
| scripts/benchmark_cpu.py | ||
| scripts/cpu_profiler.py | ||
| scripts/trial_manager.py | ||
| references/correctness.yaml | ||
| references/runtime_dispatch.yaml | ||
| references/build_system.yaml | ||
| references/simd_optimization_patterns.yaml | ||
| references/quantized_gemm_patterns.yaml | ||
| references/brgemm_patterns.yaml | ||
| references/memory_patterns.yaml | ||
| references/threading_patterns.yaml | ||
| references/dtype_optimizations.yaml | ||
| references/optimization_levels.yaml | ||
| references/implementation_reference.md | ||
| references/optimization_strategies.md | ||
| references/workflow_details.md | ||
| references/huggingface-kernels-integration.md |
287 changes: 287 additions & 0 deletions
kernel-builder/skills/cpu-kernels/references/brgemm_patterns.yaml
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,287 @@ | ||
| # brgemm and AMX Patterns | ||
| # | ||
| # AMX is NOT used directly in HF kernels. Instead, kernels call | ||
| # at::native::cpublas::brgemm() which wraps oneDNN brgemm, and | ||
| # oneDNN internally dispatches to AMX tile instructions when available. | ||
| # | ||
| # Source: kernels-community flash-attn2, megablocks, quantization-bitsandbytes | ||
|
|
||
| overview: | ||
| name: "brgemm — AMX-accelerated GEMM via PyTorch/oneDNN" | ||
| description: | | ||
| AMX (Advanced Matrix Extensions) provides hardware tile multiply on 4th Gen Intel Xeon Scalable (Sapphire Rapids) and newer, | ||
| but it's too complex to call directly (tile configuration, register allocation, | ||
| data layout constraints). Instead, CPU kernels use the brgemm API: | ||
|
|
||
| Kernel code → at::native::cpublas::brgemm() → oneDNN brgemm → AMX tiles | ||
|
|
||
| The kernel developer's responsibilities: | ||
| 1. Pack weight data in VNNI format (2-element interleave for bf16) | ||
| 2. Tile the GEMM loop to match AMX-friendly sizes (TILE_M=16, TILE_N=16, TILE_K=32) | ||
| 3. Call brgemm() for each tile | ||
| 4. Clean up via brgemm_release() | ||
|
|
||
| For small M (≤ 4 for bf16), brgemm overhead is too high — fall back to | ||
| tinygemm_kernel using AVX512 _mm512_dpbf16_ps intrinsics. | ||
|
|
||
| compiler_flags: | | ||
| # While brgemm/oneDNN dispatches to AMX internally, it is highly recommended | ||
| # to include AMX flags for GEMM kernels (as done in megablocks and flash-attn2) | ||
| # to ensure any explicit AMX packing instructions compile successfully: | ||
| cxx-flags = ["-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw", "-mavx512vbmi", "-mamx-tile", "-mamx-bf16", "-mamx-int8", "-mfma", "-mf16c", "-fopenmp"] | ||
|
|
||
| brgemm_api: | ||
| name: "at::native::cpublas::brgemm API" | ||
| signature: | | ||
| void at::native::cpublas::brgemm( | ||
| int64_t M, // rows of output tile | ||
| int64_t N, // cols of output tile | ||
| int64_t K, // reduction dimension | ||
| int64_t lda, // leading dim of A (typically K) | ||
| int64_t ldb, // leading dim of B (typically N) | ||
| int64_t ldc, // leading dim of C (typically BLOCK_N) | ||
| bool add_C, // if true, accumulate onto existing C; if false, overwrite | ||
| const void* A, // input activation (bf16) | ||
| const void* B, // weight, VNNI-packed (bf16) | ||
| float* C // output accumulator (fp32) | ||
| ); | ||
|
|
||
| void at::native::cpublas::brgemm_release(); // cleanup, call after all GEMMs done | ||
| example: | | ||
| #include <ATen/native/CPUBlas.h> | ||
|
|
||
| // GEMM: C[M×N] = A[M×K] × B[K×N], A is bf16, B is VNNI-packed bf16, C is fp32 | ||
| at::native::cpublas::brgemm( | ||
| m_size, // M | ||
| n_size, // N | ||
| K, // K | ||
| K, // lda | ||
| n_size, // ldb | ||
| BLOCK_N, // ldc | ||
| false, // overwrite C | ||
| A_ptr, // bf16* | ||
| B_vnni_ptr, // bf16*, VNNI-packed | ||
| C_ptr); // float* | ||
|
|
||
| // Always clean up after all brgemm calls are done | ||
| at::native::cpublas::brgemm_release(); | ||
|
|
||
| vnni_packing: | ||
| name: "VNNI Data Layout" | ||
| description: | | ||
| AMX/VNNI requires weight matrix B to be packed in VNNI format: | ||
| - BF16: 2 consecutive K elements interleaved per N position | ||
| - INT8: 4 consecutive K elements interleaved per N position | ||
|
|
||
| Standard layout B[K][N]: VNNI layout B[K/2][N][2] (bf16): | ||
| [b00 b01 b02 b03] [(b00,b10) (b01,b11) (b02,b12) (b03,b13)] | ||
| [b10 b11 b12 b13] [(b20,b30) (b21,b31) (b22,b32) (b23,b33)] | ||
| [b20 b21 b22 b23] ... | ||
| [b30 b31 b32 b33] | ||
| [b40 b41 b42 b43] | ||
| ... | ||
| pack_function: | | ||
| // Pack B from [K][N] to VNNI [K/vnni_block][N][vnni_block] | ||
| // Assumes K is a multiple of VNNI_BLK; otherwise the inner loop reads past the end of src. | ||
| template<typename T> | ||
| void pack_vnni(T* dst, const T* src, int K, int N) { | ||
| constexpr int VNNI_BLK = std::is_same_v<T, at::BFloat16> ? 2 : 4; | ||
| for (int k = 0; k < K; k += VNNI_BLK) { | ||
| for (int n = 0; n < N; n++) { | ||
| for (int v = 0; v < VNNI_BLK; v++) { | ||
| dst[(k / VNNI_BLK) * N * VNNI_BLK + n * VNNI_BLK + v] = | ||
| src[(k + v) * N + n]; | ||
| } | ||
| } | ||
| } | ||
| } | ||
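The index arithmetic in `pack_vnni` can be checked on a tiny matrix. The following is a minimal pure-Python sketch that mirrors the same formula on a flat row-major list; the element labels (`b00`, `b10`, ...) and the 4×4 shape are illustrative assumptions, not data from the kernels.

```python
def pack_vnni(src, K, N, vnni_blk=2):
    """Repack a flat row-major [K][N] list into VNNI [K/vnni_blk][N][vnni_blk].

    Mirrors the index formula of the C++ pack_vnni above; vnni_blk=2
    corresponds to the bf16 case.
    """
    if K % vnni_blk != 0:
        raise ValueError("K must be a multiple of vnni_blk")
    dst = [None] * (K * N)
    for k in range(0, K, vnni_blk):
        for n in range(N):
            for v in range(vnni_blk):
                dst[(k // vnni_blk) * N * vnni_blk + n * vnni_blk + v] = src[(k + v) * N + n]
    return dst


# Illustrative 4x4 matrix whose entries are labelled by (row, column).
K, N = 4, 4
B = [f"b{k}{n}" for k in range(K) for n in range(N)]
packed = pack_vnni(B, K, N)

# First packed row: rows 0 and 1 of B interleaved per column.
print(packed[: N * 2])
# Second packed row: rows 2 and 3 of B interleaved per column.
print(packed[N * 2 :])
```

Each packed row holds two consecutive K rows interleaved per N position, so the first packed row pairs row 0 with row 1 of B.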
|
|
||
| weight_conversion: | ||
| name: "Weight Conversion for brgemm (CRITICAL)" | ||
| description: | | ||
| brgemm ALWAYS requires matrix B in VNNI-interleaved layout. Every kernel | ||
| that calls brgemm must convert B beforehand. The strategy differs by kernel: | ||
|
|
||
| strategies: | ||
| megablocks_persistent: | ||
| name: "Megablocks MoE — Persistent Conversion at First Forward" | ||
| description: | | ||
| MoE expert weights are static, so VNNI packing is done ONCE at the first | ||
| forward call. The Python wrapper (cpu_moe_cpp.py) drives this: | ||
|
|
||
| 1. Transpose weights: gate_up_proj.data.transpose(-1, -2).contiguous() | ||
| 2. Call ops.convert_weight_packed(data) — C++ does VNNI packing | ||
| 3. Store result back: self.experts.gate_up_proj.data = packed_data | ||
| 4. Set self.packed_weight = True | ||
| 5. All subsequent forwards pass is_vnni=True, skipping conversion | ||
|
|
||
| For MXFP4 quantized models, also call ops.convert_scale_packed(scales) | ||
| to reorder scales from [E, N, G] → [E, NB, G, BLOCK_N] for cache | ||
| efficiency during the GEMM inner loop. | ||
| convert_weight_packed_dtypes: | | ||
| | dtype | VNNI block | Layout after packing | Extra | | ||
| |-------------|-----------|-------------------------------|------------------------------| | ||
| | bf16/fp16 | 2 | [IC/2, N, 2] | — | | ||
| | int8 | 4 | [IC/4, N, 4] | + s8s8 compensation suffix | | ||
| | fp8_e4m3 | 2 | [IC/2, N, 2] (same as bf16) | — | | ||
| | uint8 (mxfp4/int4) | special | nibble unpack + 32-way repack | get_row_size(K) = K >> 1 | | ||
| python_code: | | ||
| # In CPUMegaBlocksMoeMLP.forward() — first call only: | ||
| if not self.packed_weight: | ||
| data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous() | ||
| data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous() | ||
| if self.use_mxfp4: | ||
| self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1) | ||
| self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2) | ||
| # Also convert scales for MXFP4 | ||
| self.experts.gate_up_proj_precision_config.weight_scale.storage.data = \ | ||
| ops.convert_scale_packed(scale_data.transpose(-1, -2).contiguous()) | ||
| else: | ||
| data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1 | ||
| self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1) | ||
| self.packed_weight = True | ||
| cpp_fallback: | | ||
| // In C++ fused_experts: if is_vnni=false, convert on-the-fly (slow) | ||
| auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); | ||
|
|
||
| quantized_fused_dequant: | ||
| name: "GPTQ / BnB — Block-Interleaved Weight + Fused Dequant per-forward" | ||
| description: | | ||
| The C++ kernel receives weights ALREADY converted to block-interleaved format | ||
| by the model framework (GPTQModel / bitsandbytes). The kernel does NOT convert | ||
| raw checkpoint weights — that's done externally. | ||
|
|
||
| ## External conversion (done ONCE at first forward): | ||
| - GPTQ: GPTQModel's transform_cpu() unpacks int32→uint8, reorders by g_idx, | ||
| transposes to [N,K]; then convert_weight_packed_zp() repacks to [N,K/2] | ||
| with BLOCK_N=32 interleaving. See quantized_gemm_patterns.yaml for details. | ||
| - BnB: bitsandbytes' _convert_weight_packed_for_cpu() unpacks nibbles→[N,K], | ||
| repacks to [N,K/2] with same BLOCK_N=32 interleaving. Also transposes | ||
| absmax to [K/blocksize, N] bf16. | ||
|
|
||
| ## Kernel-side dequant (per-forward, after receiving pre-converted weights): | ||
| Quantized GEMM uses two threshold variables to decide how to handle matrix B: | ||
| - `use_brgemm` (e.g. M > 4 for bf16): Switch from tinygemm (fused dequant in loop) to brgemm. | ||
| - `use_brgemm_dequant_out` (e.g. M > 100): Control WHEN the unpack_B (dequant+VNNI) happens. | ||
|
|
||
| brgemm dequant paths: | ||
| - use_brgemm_dequant_out = true (M > 100): | ||
| Unpack ALL available blocks of B upfront into a single large Btmp tensor before the M*N loop. | ||
| Then execute brgemm using this pre-dequantized buffer. Better for Prefill phase (large M). | ||
| - use_brgemm_dequant_out = false (4 < M ≤ 100): | ||
| Unpack per K-block (BLOCK_K=128) into a small inner Btmp_inner buffer during the GEMM loop. | ||
| Executes brgemm per block. Better for Decode phase (small/medium M) to save L3 cache / memory bandwidth. | ||
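The two thresholds above amount to a three-way choice on M. A minimal sketch of that decision, assuming the example bf16 values quoted in the text (4 and 100); the function name and the returned labels are hypothetical and do not come from the kernel sources.

```python
def select_gemm_path(m, brgemm_min_m=4, dequant_out_min_m=100):
    """Pick a quantized-GEMM path from the row count M.

    Thresholds default to the example values quoted above for bf16;
    real kernels may use different numbers.
    """
    use_brgemm = m > brgemm_min_m
    use_brgemm_dequant_out = m > dequant_out_min_m
    if not use_brgemm:
        return "tinygemm"  # fused dequant inside the SIMD loop
    if use_brgemm_dequant_out:
        return "brgemm_dequant_upfront"  # unpack all of B into Btmp first
    return "brgemm_dequant_per_k_block"  # unpack one K-block at a time


for m in (1, 4, 5, 100, 101):
    print(m, select_gemm_path(m))
```

Keeping the thresholds as parameters makes it easy to re-tune them per dtype or per machine after benchmarking.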
|
|
||
| unpack_B flow: | ||
| 1. Load packed byte from block-interleaved layout → nibble split | ||
| 2. Subtract zero_point per group (GPTQ) or skip (BnB: zero in LUT) | ||
| 3. LUT lookup (NF4/FP4) or linear scale (INT4) | ||
| 4. Multiply by scale | ||
| 5. _mm512_cvtne2ps_pbh → bf16 pair (naturally VNNI 2-element) | ||
| 6. Store to Btmp[K/2][N][2] | ||
|
|
||
| The block-interleaved layout ensures 32 consecutive N-elements are together | ||
| in memory, enabling efficient AVX512 loads via _mm512_permutexvar_epi8. | ||
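Steps 1, 2 and 4 of the unpack_B flow can be illustrated for the linear INT4 case with scalar arithmetic. This is only a sketch: it assumes the low nibble holds the first element, which the text above does not specify, and it leaves out the NF4/FP4 LUT path, the bf16 conversion and the block-interleaved addressing. The function name is hypothetical.

```python
def unpack_int4_pair(byte, zero_point, scale):
    """Dequantize the two 4-bit values stored in one packed byte.

    Assumption: the low nibble is the first element and the high nibble
    the second. Check the actual kernel layout before relying on this.
    """
    lo = byte & 0x0F           # step 1: nibble split
    hi = (byte >> 4) & 0x0F
    # steps 2 and 4: subtract the group zero point, then apply the group scale
    return ((lo - zero_point) * scale, (hi - zero_point) * scale)


pair = unpack_int4_pair(0xA3, zero_point=8, scale=0.5)
print(pair)
```

In the real kernel the same arithmetic runs on whole AVX512 vectors, and the two results are then converted to a bf16 pair, which is what makes the output VNNI-shaped.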
|
|
||
| flash_attention_per_tile: | ||
| name: "Flash Attention — pack_vnni per tile (no persistent convert)" | ||
| description: | | ||
| K and V matrices change every forward, so no caching is possible. | ||
| Before each Q@K^T: pack_vnni(Btmp, K_tile) | ||
| Before each S@V: pack_vnni2(Btmp, V_tile) | ||
| Uses AVX512 16×16 transpose for efficient tile packing. | ||
|
|
||
| rmsnorm_none: | ||
| name: "RMSNorm / Element-wise — No conversion needed" | ||
| description: | | ||
| Element-wise kernels (RMSNorm, activations, reductions) do not use | ||
| brgemm and need no weight conversion. Weight tensor used as-is. | ||
|
|
||
| decision_guide: | | ||
| ┌─────────────────────────────────┬────────────────────────────────────────────┐ | ||
| │ Weight type │ Strategy │ | ||
| ├─────────────────────────────────┼────────────────────────────────────────────┤ | ||
| │ Static bf16/fp16 (MoE experts) │ convert_weight_packed() once, cache result │ | ||
| │ Static bf16/fp16 + MXFP4 scales │ convert_weight_packed() + convert_scale_packed() │ | ||
| │ Quantized INT4/NF4/FP4 │ Pre-converted by framework; kernel dequants per-forward │ | ||
| │ Dynamic matrices (K, V) │ pack_vnni() per tile, no caching │ | ||
| │ Element-wise (RMSNorm etc) │ No conversion needed │ | ||
| └─────────────────────────────────┴────────────────────────────────────────────┘ | ||
|
|
||
| Key rule: if B doesn't change between forwards → convert ONCE and cache. | ||
| If B changes (activations, K/V, or is dequantized per-forward) → convert each time. | ||
|
|
||
| tinygemm_vs_brgemm: | ||
| name: "Algorithm Selection: tinygemm vs brgemm" | ||
| description: | | ||
| Two GEMM paths exist for different M sizes: | ||
|
|
||
| 1. **tinygemm** (M ≤ 4 for bf16): Hand-written AVX512 kernel | ||
| - Uses _mm512_dpbf16_ps (AVX512-BF16 dot-product on VNNI-packed pairs, NOT AMX tiles) | ||
| - Lower overhead, better for small batch / single token | ||
| - Kernel author writes this directly with SIMD intrinsics | ||
|
|
||
| 2. **brgemm** (M > 4 for bf16): oneDNN-backed, AMX-accelerated | ||
| - Higher overhead but much higher throughput for larger M | ||
| - Kernel author just calls the API, oneDNN handles AMX internally | ||
| selection_pattern: | | ||
| template<typename scalar_t> | ||
| bool can_use_brgemm(int M) { | ||
| if constexpr (std::is_same_v<scalar_t, at::BFloat16>) { | ||
| return M > 4; // bf16: brgemm needs M > 4 | ||
| } | ||
| return M > 8; // other types: higher threshold | ||
| } | ||
|
|
||
| // In the kernel: | ||
| if (can_use_brgemm<scalar_t>(M)) { | ||
| // Tile loop calling brgemm() | ||
| for (int m = 0; m < M; m += TILE_M) { | ||
| for (int n = 0; n < N; n += TILE_N) { | ||
| at::native::cpublas::brgemm( | ||
| std::min(TILE_M, M - m), std::min(TILE_N, N - n), K, | ||
| K, N, BLOCK_N, false, | ||
| A + m * K, B_vnni + n * K_vnni, C + m * BLOCK_N + n); | ||
| } | ||
| } | ||
| at::native::cpublas::brgemm_release(); | ||
| } else { | ||
| // Small M: use hand-written tinygemm with _mm512_dpbf16_ps | ||
| tinygemm_kernel(M, N, K, A, B_vnni, C); | ||
| } | ||
|
|
||
| tile_sizes: | ||
| name: "AMX-Friendly Tile Dimensions" | ||
| description: | | ||
| When tiling for brgemm, use sizes that match AMX hardware tiles: | ||
| values: | ||
| TILE_M: 16 # AMX processes 16 rows at a time | ||
| TILE_N: 16 # AMX output is 16 cols (fp32) | ||
| TILE_K: 32 # AMX-BF16 processes 32 bf16 elements per step | ||
| note: | | ||
| These are the hardware tile dimensions. The outer blocking | ||
| (BLOCK_M, BLOCK_N) can be multiples of these for better | ||
| cache utilization, e.g. BLOCK_M=64 means 4 AMX tiles in M. | ||
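To see how the tile sizes translate into brgemm calls, the sketch below enumerates output tiles for a given M and N and clamps the edge tiles, the same way the `std::min(TILE_M, M - m)` expressions do in the selection pattern. The helper name and the 40×48 problem size are illustrative assumptions.

```python
def iter_tiles(M, N, tile_m=16, tile_n=16):
    """Yield (m, n, m_size, n_size) for every output tile.

    Edge tiles are clamped so they never run past M or N.
    """
    for m in range(0, M, tile_m):
        for n in range(0, N, tile_n):
            yield m, n, min(tile_m, M - m), min(tile_n, N - n)


# Illustrative problem: M=40 is not a multiple of 16, so the last row of tiles is short.
tiles = list(iter_tiles(40, 48))
print(len(tiles))   # number of brgemm calls this tiling would issue
print(tiles[-1])    # the bottom-right (clamped) tile
```

Partial tiles still cost a full call, which is one reason small M is routed to tinygemm instead.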
|
|
||
| when_to_use_brgemm: | ||
| applies: | ||
| - "Flash Attention: QK^T and S×V matmuls (M = seq_len, can be large)" | ||
| - "Quantized GEMM: after dequantizing to bf16, M > 4" | ||
| - "MoE expert GEMM: avg tokens per expert > 4" | ||
| does_not_apply: | ||
| - "Element-wise ops (RMSNorm, activation) — use AVX512 intrinsics directly" | ||
| - "Small reductions (single-vector operations)" | ||
| - "Single-token decode (M=1) — use tinygemm with _mm512_dpbf16_ps" | ||
| runtime_detection: | | ||
| // Check AMX availability for brgemm | ||
| // cpu_features.hpp already provides (__cpuid_count comes from <cpuid.h> on GCC/Clang): | ||
| static bool hasAMX() { | ||
| unsigned int eax, ebx, ecx, edx; | ||
| __cpuid_count(7, 0, eax, ebx, ecx, edx); | ||
| bool amx_bf16 = (edx & (1 << 22)) != 0; | ||
| bool amx_tile = (edx & (1 << 24)) != 0; | ||
| return amx_bf16 && amx_tile; | ||
| } | ||
| // Note: brgemm() will still work without AMX — oneDNN falls back | ||
| // to AVX512 internally. But performance won't be as good. | ||
| // XCR0 check for XTILECFG/XTILEDATA (bits 17, 18) is not implemented | ||
| // in existing kernels — they rely on brgemm to handle this internally. |
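On the Python side (for example in a benchmark script), the same availability question can be answered on Linux by reading the feature flags the kernel exports. This is a minimal sketch, assuming a Linux host where `/proc/cpuinfo` lists `amx_tile`, `amx_bf16` and `amx_int8`; the helper names are hypothetical and are not part of the skill's scripts.

```python
def amx_flags(cpuinfo_text):
    """Return the AMX feature flags found in /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags"):
            flags = set(line.split(":", 1)[1].split())
            return {f for f in flags if f.startswith("amx_")}
    return set()


def has_amx_bf16(cpuinfo_text):
    """True when both the tile and bf16 AMX flags are present."""
    return {"amx_tile", "amx_bf16"} <= amx_flags(cpuinfo_text)


# Illustrative input; on a real machine pass open("/proc/cpuinfo").read() instead.
sample = "processor\t: 0\nflags\t\t: fpu avx512f avx512_bf16 amx_bf16 amx_tile amx_int8\n"
print(sorted(amx_flags(sample)))
print(has_amx_bf16(sample))
```

Linux only advertises these flags when the OS has enabled the tile state, so this check also covers the XCR0 condition that the C++ snippet skips.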
It would be nice to add a note on where CPU kernels are actually helpful.
Done. Please review the new changes and rerun the CI. Thanks!