@@ -1694,6 +1694,15 @@ static void cactus_quant_panel_rescale_sb(const int32_t* psb, size_t sb, const f
16941694 for (uint32_t c = 0 ; c < valid; ++c) C[n_start + c] = static_cast <__fp16>(tmp[c]);
16951695}
16961696
// Returns the divisor the GEMV drivers use to size their thread budget
// (nt_budget = ceil(work_units / sb_per_thread)).
//
// The value is read once from the CACTUS_GEMV_SB_PER_THREAD environment variable and
// cached for the lifetime of the process; later changes to the environment are ignored.
// An unset, non-numeric, zero or negative setting falls back to the default of 8.
static size_t cactus_quant_gemv_sb_per_thread() {
    static const size_t cached = [] {
        constexpr size_t kDefault = 8;
        const char* env = getenv("CACTUS_GEMV_SB_PER_THREAD");
        if (env == nullptr) return kDefault;
        // strtol rather than atoi: atoi has undefined behaviour when the text does not fit
        // in an int, whereas strtol saturates at LONG_MAX/LONG_MIN. Leading whitespace and
        // trailing non-digits are tolerated exactly as atoi tolerated them.
        const long parsed = strtol(env, nullptr, 10);
        if (parsed <= 0) return kDefault;
        return static_cast<size_t>(parsed);
    }();
    return cached;
}
1705+
16971706// SME workers live INSIDE the thread budget, replacing NEON workers (measured frontier: flat k=2
16981707// dominates k=0 on speed AND power). Env/setter overrides; backend 2 clamps >= 1 for leaf coverage.
16991708static inline size_t cactus_quant_panel_k_sme (size_t nt, uint32_t gs) {
@@ -1712,11 +1721,7 @@ static void cactus_quant_panel_gemv(const CactusQuantMatrix* W, const __fp16* A,
17121721 const uint32_t N = W->N ;
17131722 const size_t SB64 = (static_cast <size_t >(N) + 63 ) / 64 ;
17141723 auto & pool = CactusThreading::get_thread_pool ();
1715- static const size_t sb_per_thread = [] {
1716- const char * e = getenv (" CACTUS_GEMV_SB_PER_THREAD" );
1717- const int v = e ? atoi (e) : 8 ;
1718- return static_cast <size_t >(v > 0 ? v : 8 );
1719- }();
1724+ const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread ();
17201725 const size_t nt_budget = std::max<size_t >(1 , (SB64 + sb_per_thread - 1 ) / sb_per_thread);
17211726 const size_t nt = std::min (pool.num_workers (), std::min (nt_budget, SB64 ));
17221727
@@ -1814,11 +1819,7 @@ void cactus_quant_orth_panel_gemv(const CactusQuantMatrix* W2, const __fp16* rot
18141819 const uint32_t num_groups = W2 ->num_groups ;
18151820 const size_t SB64 = (static_cast <size_t >(N) + 63 ) / 64 ;
18161821 auto & pool = CactusThreading::get_thread_pool ();
1817- static const size_t sb_per_thread = [] {
1818- const char * e = getenv (" CACTUS_GEMV_SB_PER_THREAD" );
1819- const int v = e ? atoi (e) : 8 ;
1820- return static_cast <size_t >(v > 0 ? v : 8 );
1821- }();
1822+ const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread ();
18221823 const size_t nt_budget = std::max<size_t >(1 , (SB64 + sb_per_thread - 1 ) / sb_per_thread);
18231824 const size_t nt = std::max<size_t >(1 , std::min (pool.num_workers (), std::min (nt_budget, SB64 )));
18241825
@@ -3051,6 +3052,8 @@ static void cactus_quant_interleaved4_gemv_blocks(
30513052 }
30523053}
30533054
3055+ // Fused IL GEMV: the panel driver's dispatch (group-stolen phase A, spin barrier, dynamic
3056+ // 16-block-chunk stealing, main as worker 0, spin-join) over the unchanged IL micro-kernel.
30543057void cactus_quant_4bit_gemv_interleaved (
30553058 const CactusQuantMatrix* W,
30563059 const uint8_t * packed_interleaved,
@@ -3064,37 +3067,69 @@ void cactus_quant_4bit_gemv_interleaved(
30643067 if (W->group_size > 256 ) return ;
30653068
30663069 const uint32_t gs = W->group_size ;
3067- const uint32_t pgb = cactus_quant_packed_group_bytes (4 , gs);
30683070 const uint32_t num_groups = W->num_groups ;
3071+ const size_t N_blocks = W->N / 4 ;
3072+ const size_t n_chunks = (N_blocks + 15 ) / 16 ;
3073+ auto & pool = CactusThreading::get_thread_pool ();
3074+ const size_t sb_per_thread = cactus_quant_gemv_sb_per_thread ();
3075+ const size_t nt_budget = std::max<size_t >(1 , (n_chunks + sb_per_thread - 1 ) / sb_per_thread);
3076+ const size_t nt = std::min (pool.num_workers (), std::min (nt_budget, n_chunks));
30693077
3070- thread_local std::vector<__fp16> code_basis_buf;
3071- if (code_basis_buf.size () < W->K ) code_basis_buf.resize (W->K );
3072- cactus_quant_transform_hadamard_activations (*W, x, 1 , code_basis_buf.data ());
3073- const __fp16* code_basis = code_basis_buf.data ();
3074-
3075- thread_local std::vector<int8_t > act_i8_buf;
3076- thread_local std::vector<float > act_scales_buf;
3077- if (act_i8_buf.size () < W->K ) act_i8_buf.resize (W->K );
3078- if (act_scales_buf.size () < num_groups) act_scales_buf.resize (num_groups);
3079- for (uint32_t g = 0 ; g < num_groups; ++g) {
3080- act_scales_buf[g] = tq_quantize_group_i8 (
3081- code_basis + static_cast <size_t >(g) * gs,
3082- act_i8_buf.data () + static_cast <size_t >(g) * gs, gs);
3083- }
3084- const int8_t * act_i8 = act_i8_buf.data ();
3085- const float * act_scales = act_scales_buf.data ();
3078+ static thread_local std::vector<int8_t > tl_il_act_i8;
3079+ static thread_local std::vector<float > tl_il_act_scales;
3080+ if (tl_il_act_i8.size () < W->K ) tl_il_act_i8.resize (W->K );
3081+ if (tl_il_act_scales.size () < num_groups) tl_il_act_scales.resize (num_groups);
3082+ int8_t * act_i8 = tl_il_act_i8.data ();
3083+ float * act_scales = tl_il_act_scales.data ();
30863084
30873085 int8_t cb_i8[16 ] = {};
30883086 const float cb_scale = tq_quantize_codebook_i8 (W->codebook , cb_i8, 16 );
30893087 const int8x16_t cb_lut = vld1q_s8 (cb_i8);
30903088
3091- const size_t N_blocks = W->N / 4 ;
3089+ auto phase_a_group = [&](uint32_t g) {
3090+ __fp16 basis[256 ];
3091+ cactus_quant_transform_hadamard_group (*W, x + static_cast <size_t >(g) * gs, g, basis);
3092+ act_scales[g] = tq_quantize_group_i8 (basis, act_i8 + static_cast <size_t >(g) * gs, gs);
3093+ };
30923094
3093- cactus_quant_parallel_ranges (N_blocks, 64 , [&](size_t block_start, size_t block_end) {
3095+ if (nt <= 1 ) {
3096+ for (uint32_t g = 0 ; g < num_groups; ++g) phase_a_group (g);
30943097 cactus_quant_interleaved4_gemv_blocks (W, packed_interleaved, norms_interleaved,
30953098 act_i8, act_scales, cb_lut, cb_scale,
3096- block_start, block_end, y);
3099+ 0 , N_blocks, y);
3100+ return ;
3101+ }
3102+
3103+ std::atomic<uint32_t > ga{0 };
3104+ std::atomic<uint32_t > a_done{0 };
3105+ std::atomic<uint32_t > next{0 };
3106+ std::atomic<uint32_t > done{0 };
3107+ auto worker = [&](size_t ) {
3108+ for (uint32_t g; (g = ga.fetch_add (1 , std::memory_order_relaxed)) < num_groups; ) {
3109+ phase_a_group (g);
3110+ a_done.fetch_add (1 , std::memory_order_release);
3111+ }
3112+ while (a_done.load (std::memory_order_acquire) < num_groups) { /* spin */ }
3113+ for (;;) {
3114+ const uint32_t seen = next.load (std::memory_order_relaxed);
3115+ if (seen >= n_chunks) break ;
3116+ const uint32_t want = (n_chunks - seen > 4u * nt) ? 4u : 1u ;
3117+ const uint32_t ck = next.fetch_add (want, std::memory_order_relaxed);
3118+ if (ck >= n_chunks) break ;
3119+ const uint32_t cnt = std::min<uint32_t >(want, static_cast <uint32_t >(n_chunks) - ck);
3120+ const size_t b0 = static_cast <size_t >(ck) * 16 ;
3121+ const size_t b1 = std::min (N_blocks, b0 + static_cast <size_t >(cnt) * 16 );
3122+ cactus_quant_interleaved4_gemv_blocks (W, packed_interleaved, norms_interleaved,
3123+ act_i8, act_scales, cb_lut, cb_scale,
3124+ b0, b1, y);
3125+ }
3126+ };
3127+ pool.enqueue_n_threads (nt - 1 , nt - 1 , [&](size_t wid, size_t ) {
3128+ worker (wid + 1 );
3129+ done.fetch_add (1 , std::memory_order_release);
30973130 });
3131+ worker (0 );
3132+ while (done.load (std::memory_order_acquire) < nt - 1 ) { /* spin */ }
30983133}
30993134
31003135void cactus_quant_3bit_gemv_interleaved (
0 commit comments