Skip to content

Commit e9561b8

Browse files
committed
Fuse dispatch for the legacy interleaved CQ4 GEMV
Replace the legacy IL GEMV's three-pass dispatch (pool-parallel Hadamard + cv wait, serial int8 quantize, static parallel_ranges + cv wait) with the panel driver's single fused dispatch: group-stolen phase A behind a spin barrier, dynamic 16-block-chunk stealing in phase B, main thread as worker 0 with a spin-join, shared CACTUS_GEMV_SB_PER_THREAD budget. The IL micro-kernel is unchanged. M4 kernel: kv_proj 44.7 -> 158.2 GF, o_proj 59.2 -> 188.4 GF. E2E decode on legacy bundles: gemma-4-e2b-it +19%, qwen3-1.7b +42%, lfm2-350m +26%, reaching panel-NEON parity on the unchanged file format. Signed-off-by: Noah Cylich <noahcylich@gmail.com>
1 parent ff35ac9 commit e9561b8

2 files changed

Lines changed: 71 additions & 32 deletions

File tree

cactus-kernels/src/matmul.cpp

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,15 @@ static void cactus_quant_panel_rescale_sb(const int32_t* psb, size_t sb, const f
16941694
for (uint32_t c = 0; c < valid; ++c) C[n_start + c] = static_cast<__fp16>(tmp[c]);
16951695
}
16961696

1697+
static size_t cactus_quant_gemv_sb_per_thread() {
1698+
static const size_t v = [] {
1699+
const char* e = getenv("CACTUS_GEMV_SB_PER_THREAD");
1700+
const int i = e ? atoi(e) : 8;
1701+
return static_cast<size_t>(i > 0 ? i : 8);
1702+
}();
1703+
return v;
1704+
}
1705+
16971706
// SME workers live INSIDE the thread budget, replacing NEON workers (measured frontier: flat k=2
16981707
// dominates k=0 on speed AND power). Env/setter overrides; backend 2 clamps >= 1 for leaf coverage.
16991708
static inline size_t cactus_quant_panel_k_sme(size_t nt, uint32_t gs) {
@@ -1712,11 +1721,7 @@ static void cactus_quant_panel_gemv(const CactusQuantMatrix* W, const __fp16* A,
17121721
const uint32_t N = W->N;
17131722
const size_t SB64 = (static_cast<size_t>(N) + 63) / 64;
17141723
auto& pool = CactusThreading::get_thread_pool();
1715-
static const size_t sb_per_thread = [] {
1716-
const char* e = getenv("CACTUS_GEMV_SB_PER_THREAD");
1717-
const int v = e ? atoi(e) : 8;
1718-
return static_cast<size_t>(v > 0 ? v : 8);
1719-
}();
1724+
const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread();
17201725
const size_t nt_budget = std::max<size_t>(1, (SB64 + sb_per_thread - 1) / sb_per_thread);
17211726
const size_t nt = std::min(pool.num_workers(), std::min(nt_budget, SB64));
17221727

@@ -1814,11 +1819,7 @@ void cactus_quant_orth_panel_gemv(const CactusQuantMatrix* W2, const __fp16* rot
18141819
const uint32_t num_groups = W2->num_groups;
18151820
const size_t SB64 = (static_cast<size_t>(N) + 63) / 64;
18161821
auto& pool = CactusThreading::get_thread_pool();
1817-
static const size_t sb_per_thread = [] {
1818-
const char* e = getenv("CACTUS_GEMV_SB_PER_THREAD");
1819-
const int v = e ? atoi(e) : 8;
1820-
return static_cast<size_t>(v > 0 ? v : 8);
1821-
}();
1822+
const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread();
18221823
const size_t nt_budget = std::max<size_t>(1, (SB64 + sb_per_thread - 1) / sb_per_thread);
18231824
const size_t nt = std::max<size_t>(1, std::min(pool.num_workers(), std::min(nt_budget, SB64)));
18241825

@@ -3051,6 +3052,8 @@ static void cactus_quant_interleaved4_gemv_blocks(
30513052
}
30523053
}
30533054

3055+
// Fused IL GEMV: the panel driver's dispatch (group-stolen phase A, spin barrier, dynamic
3056+
// 16-block-chunk stealing, main as worker 0, spin-join) over the unchanged IL micro-kernel.
30543057
void cactus_quant_4bit_gemv_interleaved(
30553058
const CactusQuantMatrix* W,
30563059
const uint8_t* packed_interleaved,
@@ -3064,37 +3067,69 @@ void cactus_quant_4bit_gemv_interleaved(
30643067
if (W->group_size > 256) return;
30653068

30663069
const uint32_t gs = W->group_size;
3067-
const uint32_t pgb = cactus_quant_packed_group_bytes(4, gs);
30683070
const uint32_t num_groups = W->num_groups;
3071+
const size_t N_blocks = W->N / 4;
3072+
const size_t n_chunks = (N_blocks + 15) / 16;
3073+
auto& pool = CactusThreading::get_thread_pool();
3074+
const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread();
3075+
const size_t nt_budget = std::max<size_t>(1, (n_chunks + sb_per_thread - 1) / sb_per_thread);
3076+
const size_t nt = std::min(pool.num_workers(), std::min(nt_budget, n_chunks));
30693077

3070-
thread_local std::vector<__fp16> code_basis_buf;
3071-
if (code_basis_buf.size() < W->K) code_basis_buf.resize(W->K);
3072-
cactus_quant_transform_hadamard_activations(*W, x, 1, code_basis_buf.data());
3073-
const __fp16* code_basis = code_basis_buf.data();
3074-
3075-
thread_local std::vector<int8_t> act_i8_buf;
3076-
thread_local std::vector<float> act_scales_buf;
3077-
if (act_i8_buf.size() < W->K) act_i8_buf.resize(W->K);
3078-
if (act_scales_buf.size() < num_groups) act_scales_buf.resize(num_groups);
3079-
for (uint32_t g = 0; g < num_groups; ++g) {
3080-
act_scales_buf[g] = tq_quantize_group_i8(
3081-
code_basis + static_cast<size_t>(g) * gs,
3082-
act_i8_buf.data() + static_cast<size_t>(g) * gs, gs);
3083-
}
3084-
const int8_t* act_i8 = act_i8_buf.data();
3085-
const float* act_scales = act_scales_buf.data();
3078+
static thread_local std::vector<int8_t> tl_il_act_i8;
3079+
static thread_local std::vector<float> tl_il_act_scales;
3080+
if (tl_il_act_i8.size() < W->K) tl_il_act_i8.resize(W->K);
3081+
if (tl_il_act_scales.size() < num_groups) tl_il_act_scales.resize(num_groups);
3082+
int8_t* act_i8 = tl_il_act_i8.data();
3083+
float* act_scales = tl_il_act_scales.data();
30863084

30873085
int8_t cb_i8[16] = {};
30883086
const float cb_scale = tq_quantize_codebook_i8(W->codebook, cb_i8, 16);
30893087
const int8x16_t cb_lut = vld1q_s8(cb_i8);
30903088

3091-
const size_t N_blocks = W->N / 4;
3089+
auto phase_a_group = [&](uint32_t g) {
3090+
__fp16 basis[256];
3091+
cactus_quant_transform_hadamard_group(*W, x + static_cast<size_t>(g) * gs, g, basis);
3092+
act_scales[g] = tq_quantize_group_i8(basis, act_i8 + static_cast<size_t>(g) * gs, gs);
3093+
};
30923094

3093-
cactus_quant_parallel_ranges(N_blocks, 64, [&](size_t block_start, size_t block_end) {
3095+
if (nt <= 1) {
3096+
for (uint32_t g = 0; g < num_groups; ++g) phase_a_group(g);
30943097
cactus_quant_interleaved4_gemv_blocks(W, packed_interleaved, norms_interleaved,
30953098
act_i8, act_scales, cb_lut, cb_scale,
3096-
block_start, block_end, y);
3099+
0, N_blocks, y);
3100+
return;
3101+
}
3102+
3103+
std::atomic<uint32_t> ga{0};
3104+
std::atomic<uint32_t> a_done{0};
3105+
std::atomic<uint32_t> next{0};
3106+
std::atomic<uint32_t> done{0};
3107+
auto worker = [&](size_t) {
3108+
for (uint32_t g; (g = ga.fetch_add(1, std::memory_order_relaxed)) < num_groups; ) {
3109+
phase_a_group(g);
3110+
a_done.fetch_add(1, std::memory_order_release);
3111+
}
3112+
while (a_done.load(std::memory_order_acquire) < num_groups) { /* spin */ }
3113+
for (;;) {
3114+
const uint32_t seen = next.load(std::memory_order_relaxed);
3115+
if (seen >= n_chunks) break;
3116+
const uint32_t want = (n_chunks - seen > 4u * nt) ? 4u : 1u;
3117+
const uint32_t ck = next.fetch_add(want, std::memory_order_relaxed);
3118+
if (ck >= n_chunks) break;
3119+
const uint32_t cnt = std::min<uint32_t>(want, static_cast<uint32_t>(n_chunks) - ck);
3120+
const size_t b0 = static_cast<size_t>(ck) * 16;
3121+
const size_t b1 = std::min(N_blocks, b0 + static_cast<size_t>(cnt) * 16);
3122+
cactus_quant_interleaved4_gemv_blocks(W, packed_interleaved, norms_interleaved,
3123+
act_i8, act_scales, cb_lut, cb_scale,
3124+
b0, b1, y);
3125+
}
3126+
};
3127+
pool.enqueue_n_threads(nt - 1, nt - 1, [&](size_t wid, size_t) {
3128+
worker(wid + 1);
3129+
done.fetch_add(1, std::memory_order_release);
30973130
});
3131+
worker(0);
3132+
while (done.load(std::memory_order_acquire) < nt - 1) { /* spin */ }
30983133
}
30993134

31003135
void cactus_quant_3bit_gemv_interleaved(

cactus-kernels/tests/test_matmul.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,8 @@ static bool test_orth_panel(int backend, double& mse_inc, double& mse_panel) {
563563
// legacy interleaved NEON kernel (old bundles); use_panels=true builds the panel layout from the
564564
// SAME interleaved fixture through the reference encoder and exercises the panel GEMV
565565
// (multi-super-block stealing included: N=192 = 3 super-blocks).
566-
static bool test_cq4_interleaved(bool use_panels, int backend, double& mse_out) {
567-
const uint32_t K = 1024, N = 192, gs = 128; // 192 = 3 super-blocks: exercises multi-SB stealing
566+
static bool test_cq4_interleaved(bool use_panels, int backend, double& mse_out,
567+
uint32_t K = 1024, uint32_t N = 192, uint32_t gs = 128) {
568568
SyntheticCQ cq(4, K, N, gs, 777);
569569
if (use_panels) cq.preexpand_il();
570570
CactusQuantMatrix mat = cq.matrix_interleaved();
@@ -1049,6 +1049,10 @@ int main() {
10491049
{
10501050
double m1 = 0;
10511051
runner.run_test("matmul_cq4_il[file]", test_cq4_interleaved(false, 1, m1));
1052+
// N=4164: 1041 IL blocks -> 66 chunks (16-block + 1-block tail) -> multi-thread fused
1053+
// driver (phase-A stealing, spin barrier, 4-chunk grabs); N=192 stays on the serial path.
1054+
double m_mt = 0;
1055+
runner.run_test("matmul_cq4_il_mt[file]", test_cq4_interleaved(false, 1, m_mt, 1024, 4164, 128));
10521056
runner.run_test("panel_layout_invariance", test_panel_layout_invariance());
10531057
runner.run_test("orth_embed_rows_batched", test_orth_embed_rows());
10541058
}

0 commit comments

Comments
 (0)