Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/builder-cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ Install a kernels skill for an AI assistant

Default value: `cuda-kernels`

Possible values: `cuda-kernels`, `rocm-kernels`, `xpu-kernels`
Possible values: `cuda-kernels`, `cpu-kernels`, `rocm-kernels`, `xpu-kernels`

* `--claude` — Install for Claude
* `--codex` — Install for Codex
Expand Down
2 changes: 1 addition & 1 deletion docs/source/builder/agents-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Code agents are a good fit to build custom kernels because the hard part is not just writing in Domain Specific Language (DSLs) like CUDA. You also need the right project layout, PyTorch bindings, architecture-specific choices, model-specific integration, and trustworthy benchmarks.

Kernels on Hugging Face are compatible with agents via skills and the `hf` CLI. The `cuda-kernels`, `rocm-kernels`, and `xpu-kernels` skills contain knowledge so an agent can generate and publish a complete kernel project, instead of isolated snippets.
Kernels on Hugging Face are compatible with agents via skills and the `hf` CLI. The `cuda-kernels`, `rocm-kernels`, `xpu-kernels`, and `cpu-kernels` skills contain knowledge so an agent can generate and publish a complete kernel project, instead of isolated snippets.

This guide is for **authoring new kernels**. If you only want to **load an existing precompiled kernel**, use `get_kernel()` instead.

Expand Down
12 changes: 12 additions & 0 deletions docs/source/cli-skills.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,23 @@ Supported skills include:
- `cuda-kernels` (default)
- `rocm-kernels`
- `xpu-kernels`
- `cpu-kernels`

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to add a note on where CPU kernels are actually helpful.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Please review the new changes and rerun the CI. Thanks!


Skill files are downloaded from the `huggingface/kernels` directory in this [repository](https://github.com/huggingface/kernels/tree/main/kernel-builder/skills).

Skills instruct agents how to deal with hardware-specific optimizations, integrate with libraries like diffusers and transformers, and benchmark kernel performance in consistent ways.

> [!TIP]
> **When are CPU kernels actually helpful?** Two main cases:
> - **Better performance on Intel Xeon** — custom AVX2/AVX512 kernels (and AMX via brgemm for quantized GEMM) outperform generic PyTorch ops for element-wise and quantized workloads, especially in CPU-only or latency-sensitive serving.
> - **Enabling functionality that otherwise can't run** — some kernels are a hard requirement, e.g. `megablocks` MoE on CPU, where without the kernel you simply cannot run MXFP4.
Comment on lines +14 to +17

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Can you provide some example kernels that you have built for CPU?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


Example CPU kernels built with this skill (available on the Hub under [`kernels-community`](https://huggingface.co/kernels-community)):

- [`kernels-community/megablocks`](https://huggingface.co/kernels-community/megablocks) — MoE kernels with a CPU backend that enable running MXFP4 MoE models on CPU.
- [`kernels-community/quantization-gptq`](https://huggingface.co/kernels-community/quantization-gptq) — INT4 quantized GEMM using AVX512.
- [`kernels-community/rmsnorm`](https://huggingface.co/kernels-community/rmsnorm) — RMSNorm with AVX2/AVX512 element-wise paths.

Examples:

```bash
Expand Down
455 changes: 455 additions & 0 deletions kernel-builder/skills/cpu-kernels/SKILL.md

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions kernel-builder/skills/cpu-kernels/manifest.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
SKILL.md
manifest.txt
scripts/config.yaml
scripts/config.py
scripts/analyze_op.py
scripts/validate_cpu_kernel.py
scripts/benchmark_cpu.py
scripts/cpu_profiler.py
scripts/trial_manager.py
references/correctness.yaml
references/runtime_dispatch.yaml
references/build_system.yaml
references/simd_optimization_patterns.yaml
references/quantized_gemm_patterns.yaml
references/brgemm_patterns.yaml
references/memory_patterns.yaml
references/threading_patterns.yaml
references/dtype_optimizations.yaml
references/optimization_levels.yaml
references/implementation_reference.md
references/optimization_strategies.md
references/workflow_details.md
references/huggingface-kernels-integration.md
287 changes: 287 additions & 0 deletions kernel-builder/skills/cpu-kernels/references/brgemm_patterns.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# brgemm and AMX Patterns
#
# AMX is NOT used directly in HF kernels. Instead, kernels call
# at::native::cpublas::brgemm() which wraps oneDNN brgemm, and
# oneDNN internally dispatches to AMX tile instructions when available.
#
# Source: kernels-community flash-attn2, megablocks, quantization-bitsandbytes

overview:
name: "brgemm — AMX-accelerated GEMM via PyTorch/oneDNN"
description: |
AMX (Advanced Matrix Extensions) provides hardware tile multiply on Intel Xeon 4+,
but it's too complex to call directly (tile configuration, register allocation,
data layout constraints). Instead, CPU kernels use the brgemm API:

Kernel code → at::native::cpublas::brgemm() → oneDNN brgemm → AMX tiles

The kernel developer's responsibilities:
1. Pack weight data in VNNI format (2-element interleave for bf16)
2. Tile the GEMM loop to match AMX-friendly sizes (TILE_M=16, TILE_N=16, TILE_K=32)
3. Call brgemm() for each tile
4. Clean up via brgemm_release()

For small M (≤ 4 for bf16), brgemm overhead is too high — fall back to
tinygemm_kernel using AVX512 _mm512_dpbf16_ps intrinsics.

compiler_flags: |
# While brgemm/oneDNN dispatches to AMX internally, it is highly recommended
# to include AMX flags for GEMM kernels (as done in megablocks and flash-attn2)
# to ensure any explicit AMX packing instructions compile successfully:
cxx-flags = ["-mavx512f", "-mavx512bf16", "-mavx512vl", "-mavx512dq", "-mavx512bw", "-mavx512vbmi", "-mamx-tile", "-mamx-bf16", "-mamx-int8", "-mfma", "-mf16c", "-fopenmp"]

brgemm_api:
name: "at::native::cpublas::brgemm API"
signature: |
void at::native::cpublas::brgemm(
int64_t M, // rows of output tile
int64_t N, // cols of output tile
int64_t K, // reduction dimension
int64_t lda, // leading dim of A (typically K)
int64_t ldb, // leading dim of B (typically N)
int64_t ldc, // leading dim of C (typically BLOCK_N)
bool add_C, // if true, accumulate onto existing C; if false, overwrite
const void* A, // input activation (bf16)
const void* B, // weight, VNNI-packed (bf16)
float* C // output accumulator (fp32)
);

void at::native::cpublas::brgemm_release(); // cleanup, call after all GEMMs done
example: |
#include <ATen/native/CPUBlas.h>

// GEMM: C[M×N] = A[M×K] × B[K×N], A is bf16, B is VNNI-packed bf16, C is fp32
at::native::cpublas::brgemm(
m_size, // M
n_size, // N
K, // K
K, // lda
n_size, // ldb
BLOCK_N, // ldc
false, // overwrite C
A_ptr, // bf16*
B_vnni_ptr, // bf16*, VNNI-packed
C_ptr); // float*

// Always clean up after all brgemm calls are done
at::native::cpublas::brgemm_release();

vnni_packing:
name: "VNNI Data Layout"
description: |
AMX/VNNI requires weight matrix B to be packed in VNNI format:
- BF16: 2 consecutive K elements interleaved per N position
- INT8: 4 consecutive K elements interleaved per N position

Standard layout B[K][N]: VNNI layout B[K/2][N][2] (bf16):
[b00 b01 b02 b03] [(b00,b40) (b01,b41) (b02,b42) (b03,b43)]
[b10 b11 b12 b13] [(b10,b50) (b11,b51) (b12,b52) (b13,b53)]
[b20 b21 b22 b23] ...
[b30 b31 b32 b33]
[b40 b41 b42 b43]
...
pack_function: |
// Pack B from [K][N] to VNNI [K/vnni_block][N][vnni_block]
template<typename T>
void pack_vnni(T* dst, const T* src, int K, int N) {
constexpr int VNNI_BLK = std::is_same_v<T, at::BFloat16> ? 2 : 4;
for (int k = 0; k < K; k += VNNI_BLK) {
for (int n = 0; n < N; n++) {
for (int v = 0; v < VNNI_BLK; v++) {
dst[(k / VNNI_BLK) * N * VNNI_BLK + n * VNNI_BLK + v] =
src[(k + v) * N + n];
}
}
}
}

weight_conversion:
name: "Weight Conversion for brgemm (CRITICAL)"
description: |
brgemm ALWAYS requires matrix B in VNNI-interleaved layout. Every kernel
that calls brgemm must convert B beforehand. The strategy differs by kernel:

strategies:
megablocks_persistent:
name: "Megablocks MoE — Persistent Conversion at First Forward"
description: |
MoE expert weights are static, so VNNI packing is done ONCE at the first
forward call. The Python wrapper (cpu_moe_cpp.py) drives this:

1. Transpose weights: gate_up_proj.data.transpose(-1, -2).contiguous()
2. Call ops.convert_weight_packed(data) — C++ does VNNI packing
3. Store result back: self.experts.gate_up_proj.data = packed_data
4. Set self.packed_weight = True
5. All subsequent forwards pass is_vnni=True, skipping conversion

For MXFP4 quantized models, also call ops.convert_scale_packed(scales)
to reorder scales from [E, N, G] → [E, NB, G, BLOCK_N] for cache
efficiency during the GEMM inner loop.
convert_weight_packed_dtypes: |
| dtype | VNNI block | Layout after packing | Extra |
|-------------|-----------|-------------------------------|------------------------------|
| bf16/fp16 | 2 | [IC/2, N, 2] | — |
| int8 | 4 | [IC/4, N, 4] | + s8s8 compensation suffix |
| fp8_e4m3 | 2 | [IC/2, N, 2] (same as bf16) | — |
| uint8 (mxfp4/int4) | special | nibble unpack + 32-way repack | get_row_size(K) = K >> 1 |
python_code: |
# In CPUMegaBlocksMoeMLP.forward() — first call only:
if not self.packed_weight:
data_1 = self.experts.gate_up_proj.data.transpose(-1, -2).contiguous()
data_2 = self.experts.down_proj.data.transpose(-1, -2).contiguous()
if self.use_mxfp4:
self.experts.gate_up_proj.storage.data = ops.convert_weight_packed(data_1)
self.experts.down_proj.storage.data = ops.convert_weight_packed(data_2)
# Also convert scales for MXFP4
self.experts.gate_up_proj_precision_config.weight_scale.storage.data = \
ops.convert_scale_packed(scale_data.transpose(-1, -2).contiguous())
else:
data_1 = data_1.to(torch.bfloat16) if data_1.dtype == torch.float32 else data_1
self.experts.gate_up_proj.data = ops.convert_weight_packed(data_1)
self.packed_weight = True
cpp_fallback: |
// In C++ fused_experts: if is_vnni=false, convert on-the-fly (slow)
auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1);

quantized_fused_dequant:
name: "GPTQ / BnB — Block-Interleaved Weight + Fused Dequant per-forward"
description: |
The C++ kernel receives weights ALREADY converted to block-interleaved format
by the model framework (GPTQModel / bitsandbytes). The kernel does NOT convert
raw checkpoint weights — that's done externally.

## External conversion (done ONCE at first forward):
- GPTQ: GPTQModel's transform_cpu() unpacks int32→uint8, reorders by g_idx,
transposes to [N,K]; then convert_weight_packed_zp() repacks to [N,K/2]
with BLOCK_N=32 interleaving. See quantized_gemm_patterns.yaml for details.
- BnB: bitsandbytes' _convert_weight_packed_for_cpu() unpacks nibbles→[N,K],
repacks to [N,K/2] with same BLOCK_N=32 interleaving. Also transposes
absmax to [K/blocksize, N] bf16.

## Kernel-side dequant (per-forward, after receiving pre-converted weights):
Quantized GEMM uses two threshold variables to decide how to handle matrix B:
- `use_brgemm` (e.g. M > 4 for bf16): Switch from tinygemm (fused dequant in loop) to brgemm.
- `use_brgemm_dequant_out` (e.g. M > 100): Control WHEN the unpack_B (dequant+VNNI) happens.

brgemm dequant paths:
- use_brgemm_dequant_out = true (M > 100):
Unpack ALL available blocks of B upfront into a single large Btmp tensor before the M*N loop.
Then execute brgemm using this pre-dequantized buffer. Better for Prefill phase (large M).
- use_brgemm_dequant_out = false (4 < M ≤ 100):
Unpack per K-block (BLOCK_K=128) into a small inner Btmp_inner buffer during the GEMM loop.
Executes brgemm per block. Better for Decode phase (small/medium M) to save L3 cache / memory bandwidth.

unpack_B flow:
1. Load packed byte from block-interleaved layout → nibble split
2. Subtract zero_point per group (GPTQ) or skip (BnB: zero in LUT)
3. LUT lookup (NF4/FP4) or linear scale (INT4)
4. Multiply by scale
5. _mm512_cvtne2ps_pbh → bf16 pair (naturally VNNI 2-element)
6. Store to Btmp[K/2][N][2]

The block-interleaved layout ensures 32 consecutive N-elements are together
in memory, enabling efficient AVX512 loads via _mm512_permutexvar_epi8.

flash_attention_per_tile:
name: "Flash Attention — pack_vnni per tile (no persistent convert)"
description: |
K and V matrices change every forward, so no caching is possible.
Before each Q@K^T: pack_vnni(Btmp, K_tile)
Before each S@V: pack_vnni2(Btmp, V_tile)
Uses AVX512 16×16 transpose for efficient tile packing.

rmsnorm_none:
name: "RMSNorm / Element-wise — No conversion needed"
description: |
Element-wise kernels (RMSNorm, activations, reductions) do not use
brgemm and need no weight conversion. Weight tensor used as-is.

decision_guide: |
┌─────────────────────────────────┬────────────────────────────────────────────┐
│ Weight type │ Strategy │
├─────────────────────────────────┼────────────────────────────────────────────┤
│ Static bf16/fp16 (MoE experts) │ convert_weight_packed() once, cache result │
│ Static bf16/fp16 + MXFP4 scales │ convert_weight_packed() + convert_scale_packed() │
│ Quantized INT4/NF4/FP4 │ Pre-converted by framework; kernel dequants per-forward │
│ Dynamic matrices (K, V) │ pack_vnni() per tile, no caching │
│ Element-wise (RMSNorm etc) │ No conversion needed │
└─────────────────────────────────┴────────────────────────────────────────────┘

Key rule: if B doesn't change between forwards → convert ONCE and cache.
If B changes (activations, K/V, or is dequantized per-forward) → convert each time.

tinygemm_vs_brgemm:
name: "Algorithm Selection: tinygemm vs brgemm"
description: |
Two GEMM paths exist for different M sizes:

1. **tinygemm** (M ≤ 4 for bf16): Hand-written AVX512 kernel
- Uses _mm512_dpbf16_ps (VNNI dot-product, NOT AMX tiles)
- Lower overhead, better for small batch / single token
- Kernel author writes this directly with SIMD intrinsics

2. **brgemm** (M > 4 for bf16): oneDNN-backed, AMX-accelerated
- Higher overhead but much higher throughput for larger M
- Kernel author just calls the API, oneDNN handles AMX internally
selection_pattern: |
template<typename scalar_t>
bool can_use_brgemm(int M) {
if constexpr (std::is_same_v<scalar_t, at::BFloat16>) {
return M > 4; // bf16: brgemm needs M > 4
}
return M > 8; // other types: higher threshold
}

// In the kernel:
if (can_use_brgemm<scalar_t>(M)) {
// Tile loop calling brgemm()
for (int m = 0; m < M; m += TILE_M) {
for (int n = 0; n < N; n += TILE_N) {
at::native::cpublas::brgemm(
std::min(TILE_M, M - m), std::min(TILE_N, N - n), K,
K, N, BLOCK_N, false,
A + m * K, B_vnni + n * K_vnni, C + m * BLOCK_N + n);
}
}
at::native::cpublas::brgemm_release();
} else {
// Small M: use hand-written tinygemm with _mm512_dpbf16_ps
tinygemm_kernel(M, N, K, A, B_vnni, C);
}

tile_sizes:
name: "AMX-Friendly Tile Dimensions"
description: |
When tiling for brgemm, use sizes that match AMX hardware tiles:
values:
TILE_M: 16 # AMX processes 16 rows at a time
TILE_N: 16 # AMX output is 16 cols (fp32)
TILE_K: 32 # AMX-BF16 processes 32 bf16 elements per step
note: |
These are the hardware tile dimensions. The outer blocking
(BLOCK_M, BLOCK_N) can be multiples of these for better
cache utilization, e.g. BLOCK_M=64 means 4 AMX tiles in M.

when_to_use_brgemm:
applies:
- "Flash Attention: QK^T and S×V matmuls (M = seq_len, can be large)"
- "Quantized GEMM: after dequantizing to bf16, M > 4"
- "MoE expert GEMM: avg tokens per expert > 4"
does_not_apply:
- "Element-wise ops (RMSNorm, activation) — use AVX512 intrinsics directly"
- "Small reductions (single-vector operations)"
- "Single-token decode (M=1) — use tinygemm with _mm512_dpbf16_ps"
runtime_detection: |
// Check AMX availability for brgemm
// cpu_features.hpp already provides:
static bool hasAMX() {
unsigned int eax, ebx, ecx, edx;
__cpuid_count(7, 0, eax, ebx, ecx, edx);
bool amx_bf16 = (edx & (1 << 22)) != 0;
bool amx_tile = (edx & (1 << 24)) != 0;
return amx_bf16 && amx_tile;
}
// Note: brgemm() will still work without AMX — oneDNN falls back
// to AVX512 internally. But performance won't be as good.
// XCR0 check for XTILEDATA/XTILECFG (bits 17,18) is not implemented
// in existing kernels — they rely on brgemm to handle this internally.
Loading
Loading