diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 005788d3..954a7408 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -33,7 +33,7 @@ What did you run to convince yourself this works? Be specific. For inference changes, real-checkpoint validation is required — synthetic-only is not enough. --> -- [ ] `cargo fmt --check` +- [ ] `cargo fmt --all -- --check` (enforced by CI — violations block merge) - [ ] `cargo clippy --all-targets -- -D warnings` - [ ] `cargo test --release` - [ ] `cargo deny check` diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b527ed3..e5276a4a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,8 @@ -# PR / push CI for security audit. Builds and tests are covered by +# PR / push CI for lightweight quality gates. Builds and tests are covered by # release.yml (signed artifacts) and pipeline-parallel-ci.yml (distributed -# runs); this workflow runs the lightweight cargo-deny gate on every -# touched-Rust change so license drift and new advisories are caught at -# PR time. +# runs); this workflow runs cargo-deny (license / advisory) and cargo-fmt +# checks on every touched-Rust change so formatting drift and license issues +# are caught at PR time. name: CI @@ -54,3 +54,19 @@ jobs: with: command: check log-level: warn + + fmt: + name: cargo-fmt + needs: changes + if: needs.changes.outputs.rust == 'true' + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all -- --check diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1b0af1f9..b0ffa3d1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,9 +41,9 @@ Thank you for your interest in contributing to mlxcel! This document covers the ``` 4. Run the local quality gates: ```bash - cargo fmt --check + cargo fmt --all -- --check # enforced by CI; fmt violations block merge cargo clippy --all-targets -- -D warnings - cargo deny check # advisories + licenses + sources + cargo deny check # advisories + licenses + sources ``` 5. For inference changes, validate against a real checkpoint — synthetic or build-only validation is not enough (see [`AGENTS.md`](AGENTS.md) for why). 6. Commit with a conventional prefix (see below) and a clear message. diff --git a/src/lib/mlxcel-core/src/cache.rs b/src/lib/mlxcel-core/src/cache.rs index 83d2645c..739d0056 100644 --- a/src/lib/mlxcel-core/src/cache.rs +++ b/src/lib/mlxcel-core/src/cache.rs @@ -2373,10 +2373,8 @@ impl KVCache { let k_slice = ffi::slice(k_int8, &[0, 0, 0, 0], &[ks[0], ks[1], live_len, ks[3]]); let v_slice = ffi::slice(v_int8, &[0, 0, 0, 0], &[vs[0], vs[1], live_len, vs[3]]); - let ks_slice = - ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]); - let vs_slice = - ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]); + let ks_slice = ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]); + let vs_slice = ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]); ( dequantize(&k_slice, &ks_slice), @@ -3785,8 +3783,7 @@ impl RotatingKVCache { return (new_keys, new_values); } - let (base_k, base_v, current_seq_len) = - self.visible_fp16_prefix_for_concat(); + let (base_k, base_v, current_seq_len) = self.visible_fp16_prefix_for_concat(); let concat_k = concatenate(&base_k, &new_keys, 2); let concat_v = concatenate(&base_v, &new_values, 2); @@ -5681,10 +5678,7 @@ mod tests { .collect::>() }; assert_eq!(to_f32(&visible_keys), vec![1.0, 5.0, 6.0, 7.0, 8.0]); - assert_eq!( - to_f32(&visible_values), - vec![10.0, 50.0, 60.0, 70.0, 80.0] - ); + assert_eq!(to_f32(&visible_values), vec![10.0, 50.0, 60.0, 70.0, 80.0]); } #[test] @@ -6190,10 +6184,8 @@ mod tests { } let (q_unrot, _) = unit_token(42); let q_ref = rotate_at(&q_unrot, M); - let (k_ref, v_ref) = cache_ref.update_and_fetch( - rotate_at(&unit_token(99).0, M), - unit_token(99).1, - ); + let (k_ref, v_ref) = + cache_ref.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1); let out_ref = mlxcel_core::causal_attention(&q_ref, &k_ref, &v_ref, scale, 0.0, 0); let out_ref_f32 = to_f32(&out_ref); @@ -6216,11 +6208,10 @@ mod tests { // position `M` to simulate the pre-fix offset decrement. assert_eq!(cache_broken.trim_front(N), N); let q_broken = rotate_at(&q_unrot, M); - let (k_broken, v_broken) = cache_broken.update_and_fetch( - rotate_at(&unit_token(99).0, M), - unit_token(99).1, - ); - let out_broken = mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0); + let (k_broken, v_broken) = + cache_broken.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1); + let out_broken = + mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0); let out_broken_f32 = to_f32(&out_broken); let mut sq_err = 0.0_f64; diff --git a/src/lib/mlxcel-core/src/cache/detach.rs b/src/lib/mlxcel-core/src/cache/detach.rs index 5741a6b4..e5691767 100644 --- a/src/lib/mlxcel-core/src/cache/detach.rs +++ b/src/lib/mlxcel-core/src/cache/detach.rs @@ -241,19 +241,14 @@ impl DetachedKVCache { // Slice axis 2 (seq-len) of each tensor to the new length. Width axes // (B, H, head_dim / packed_dim) come from the existing shape. - let trim_axis_seq = |a: &Option>, - tail_axis: i32| - -> Option> { - a.as_ref().map(|arr| { - let shape = ffi::array_shape(arr); - let last = if tail_axis == 0 { shape[3] } else { tail_axis }; - ffi::slice( - arr, - &[0, 0, 0, 0], - &[shape[0], shape[1], new_len, last], - ) - }) - }; + let trim_axis_seq = + |a: &Option>, tail_axis: i32| -> Option> { + a.as_ref().map(|arr| { + let shape = ffi::array_shape(arr); + let last = if tail_axis == 0 { shape[3] } else { tail_axis }; + ffi::slice(arr, &[0, 0, 0, 0], &[shape[0], shape[1], new_len, last]) + }) + }; if self.mode == KVCacheMode::Turbo4Delegated { // K is unified — same shape contract as Fp16. Slice to `new_len`. diff --git a/src/lib/mlxcel-core/src/cache/detach_tests.rs b/src/lib/mlxcel-core/src/cache/detach_tests.rs index c3543048..a0defb3a 100644 --- a/src/lib/mlxcel-core/src/cache/detach_tests.rs +++ b/src/lib/mlxcel-core/src/cache/detach_tests.rs @@ -677,10 +677,7 @@ fn detached_kvcache_trim_to_mid_buffer_resets_offset_and_visible_keys() { #[test] fn detached_kvcache_trim_to_zero_drops_storage_and_keeps_mode() { let mut live = KVCache::new_with_mode(KVCacheMode::Int8); - live.update( - fp32_tokens(&[1.0, 2.0, 3.0]), - fp32_tokens(&[4.0, 5.0, 6.0]), - ); + live.update(fp32_tokens(&[1.0, 2.0, 3.0]), fp32_tokens(&[4.0, 5.0, 6.0])); let mut handle = live.clone_handle(); assert_eq!(handle.seq_len(), 3); assert_eq!(handle.mode(), KVCacheMode::Int8); @@ -698,7 +695,9 @@ fn detached_kvcache_trim_to_exact_offset_is_noop() { live.update(fp32_tokens(&[1.0, 2.0]), fp32_tokens(&[3.0, 4.0])); let mut handle = live.clone_handle(); - handle.trim_to(2).expect("trim_to exact offset must succeed"); + handle + .trim_to(2) + .expect("trim_to exact offset must succeed"); assert_eq!(handle.seq_len(), 2); let mut restored = KVCache::new(); @@ -825,7 +824,9 @@ fn detached_cache_set_truncate_to_zero_drops_every_layer() { let v: Vec = (1..=4).map(|i| i as f32 * 10.0).collect(); let mut detached = detached_set_with_fp16_layers(2, &k, &v); - detached.truncate_to(0).expect("truncate_to zero must succeed"); + detached + .truncate_to(0) + .expect("truncate_to zero must succeed"); assert!(detached.caches.iter().all(|c| c.is_empty())); assert_eq!(detached.current_offset, 0); assert_eq!(detached.prompt_len, 0); @@ -940,10 +941,7 @@ fn detached_cache_set_truncate_to_int8_preserves_dequantization() { let seq_adopted = pool.adopt(&model, detached).unwrap(); let (k_a, v_a) = { let caches = pool.get_caches_mut(seq_adopted).unwrap(); - caches[0].update_and_fetch( - fp32_tokens(&full_k[5..]), - fp32_tokens(&full_v[5..]), - ) + caches[0].update_and_fetch(fp32_tokens(&full_k[5..]), fp32_tokens(&full_v[5..])) }; eval(&k_a); eval(&v_a); diff --git a/src/lib/mlxcel-core/src/drafter/dflash/attention.rs b/src/lib/mlxcel-core/src/drafter/dflash/attention.rs index dd78e500..3ffc7167 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/attention.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/attention.rs @@ -441,4 +441,3 @@ mod tests { ); } } - diff --git a/src/lib/mlxcel-core/src/drafter/dflash/drafter.rs b/src/lib/mlxcel-core/src/drafter/dflash/drafter.rs index ba4972e3..ca3ee0b4 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/drafter.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/drafter.rs @@ -106,12 +106,11 @@ impl DFlashDrafter { path: config_path.display().to_string(), source: e, })?; - let config = DFlashConfig::from_json(&config_json).map_err(|e| { - DrafterError::ConfigParse { + let config = + DFlashConfig::from_json(&config_json).map_err(|e| DrafterError::ConfigParse { path: config_path.display().to_string(), source: serde::de::Error::custom(e), - } - })?; + })?; let mut weights = crate::weights::load_weights_from_dir(path) .map_err(|msg| DrafterError::LoadFailed { reason: msg })?; @@ -209,8 +208,9 @@ impl Drafter for DFlashDrafter { // expose `embed_tokens` (e.g. a non-Qwen-3.5 model fed in by // mistake). if self.model.needs_embed_binding() { - let embed = target.embed_tokens_module().ok_or_else(|| { - DrafterError::BindFailed { + let embed = target + .embed_tokens_module() + .ok_or_else(|| DrafterError::BindFailed { reason: format!( "DFlash drafter checkpoint omits embed_tokens.weight \ and the target does not expose embed_tokens_module(); \ @@ -220,8 +220,7 @@ impl Drafter for DFlashDrafter { (kind = {})", self.kind() ), - } - })?; + })?; self.model.bind_target_embedding(embed); } else { // Legacy capability smoke-test for self-contained checkpoints: @@ -427,10 +426,7 @@ fn sample_block_per_position_batched( sampler: &SamplingConfig, ) -> Result>, DrafterError> { let shape = ffi::array_shape(logits); - if shape.len() != 3 - || shape[0] != batch_size as i32 - || shape[1] != block_size as i32 - { + if shape.len() != 3 || shape[0] != batch_size as i32 || shape[1] != block_size as i32 { return Err(DrafterError::DraftFailed { reason: format!( "DFlash drafter (batched) expected logits shape \ @@ -468,11 +464,7 @@ fn sample_block_per_position_batched( for i in 0..n { // Row `(b, i+1)` of the [B, L, V] logits. let pos = (i + 1) as i32; - let row = ffi::slice( - logits, - &[b, pos, 0_i32], - &[b + 1, pos + 1, vocab], - ); + let row = ffi::slice(logits, &[b, pos, 0_i32], &[b + 1, pos + 1, vocab]); // Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample // / argmax expect `[batch, vocab]`). let row = ffi::reshape(&row, &[1_i32, vocab]); @@ -536,7 +528,11 @@ fn sample_block_per_position( for i in 0..n { // Row `i + 1` of the [1, L, V] logits. let row_idx = (i + 1) as i32; - let row = ffi::slice(logits, &[0_i32, row_idx, 0_i32], &[1_i32, row_idx + 1, vocab]); + let row = ffi::slice( + logits, + &[0_i32, row_idx, 0_i32], + &[1_i32, row_idx + 1, vocab], + ); // Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample // / argmax expect `[batch, vocab]`). let row = ffi::reshape(&row, &[1_i32, vocab]); @@ -588,10 +584,7 @@ fn sample_block_per_position_array( } let tokens = sample_block_per_position(logits, block_size, sampler)?; - Ok(ffi::from_slice_i32( - &tokens, - &[1, (block_size - 1) as i32], - )) + Ok(ffi::from_slice_i32(&tokens, &[1, (block_size - 1) as i32])) } #[cfg(test)] @@ -620,10 +613,7 @@ mod tests { ffi::zeros(&[4, 4], dtype::BFLOAT16), ); // A non-bf16 tensor: should pass through. - weights.insert( - "fc.weight".to_string(), - ffi::zeros(&[4, 4], dtype::FLOAT16), - ); + weights.insert("fc.weight".to_string(), ffi::zeros(&[4, 4], dtype::FLOAT16)); convert_bf16_to_f16_non_quantized(&mut weights); diff --git a/src/lib/mlxcel-core/src/drafter/dflash/layer.rs b/src/lib/mlxcel-core/src/drafter/dflash/layer.rs index 5ec75826..4bde955b 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/layer.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/layer.rs @@ -83,18 +83,13 @@ impl DFlashDecoderLayer { let post_attention_layernorm_w = weights .get(&format!("{prefix}.post_attention_layernorm.weight")) .map(|w| ffi::copy(w)) - .ok_or_else(|| { - format!("Weight not found: {prefix}.post_attention_layernorm.weight") - })?; + .ok_or_else(|| format!("Weight not found: {prefix}.post_attention_layernorm.weight"))?; Ok(Self { self_attn, mlp, input_layernorm: RMSNorm::new(input_layernorm_w, config.rms_norm_eps), - post_attention_layernorm: RMSNorm::new( - post_attention_layernorm_w, - config.rms_norm_eps, - ), + post_attention_layernorm: RMSNorm::new(post_attention_layernorm_w, config.rms_norm_eps), }) } } diff --git a/src/lib/mlxcel-core/src/drafter/dflash/mlp.rs b/src/lib/mlxcel-core/src/drafter/dflash/mlp.rs index 63b83916..05eec477 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/mlp.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/mlp.rs @@ -67,24 +67,12 @@ impl DFlashMlp { group_size: i32, bits: i32, ) -> Result { - let gate_proj = UnifiedLinear::from_weights( - weights, - &format!("{prefix}.gate_proj"), - group_size, - bits, - )?; - let up_proj = UnifiedLinear::from_weights( - weights, - &format!("{prefix}.up_proj"), - group_size, - bits, - )?; - let down_proj = UnifiedLinear::from_weights( - weights, - &format!("{prefix}.down_proj"), - group_size, - bits, - )?; + let gate_proj = + UnifiedLinear::from_weights(weights, &format!("{prefix}.gate_proj"), group_size, bits)?; + let up_proj = + UnifiedLinear::from_weights(weights, &format!("{prefix}.up_proj"), group_size, bits)?; + let down_proj = + UnifiedLinear::from_weights(weights, &format!("{prefix}.down_proj"), group_size, bits)?; Ok(Self { gate_proj, up_proj, diff --git a/src/lib/mlxcel-core/src/drafter/dflash/mod.rs b/src/lib/mlxcel-core/src/drafter/dflash/mod.rs index a94f08a2..b0db0888 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/mod.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/mod.rs @@ -104,8 +104,7 @@ pub(crate) fn materialize_argmax_i32_vec(argmax: &MlxArray, expected_len: usize) .take(expected_len) .map(|chunk| { i64::from_ne_bytes([ - chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], - chunk[7], + chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7], ]) as i32 }) .collect(), diff --git a/src/lib/mlxcel-core/src/drafter/dflash/model.rs b/src/lib/mlxcel-core/src/drafter/dflash/model.rs index 2be89960..5bea0ce3 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/model.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/model.rs @@ -464,8 +464,14 @@ mod tests { fn tiny_weights_without_embed() -> WeightMap { let mut w: WeightMap = std::collections::HashMap::new(); // fc: [hidden, num_target_layers * hidden] = [4, 16] - w.insert("fc.weight".to_string(), ffi::zeros(&[4, 16], dtype::FLOAT32)); - w.insert("hidden_norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32)); + w.insert( + "fc.weight".to_string(), + ffi::zeros(&[4, 16], dtype::FLOAT32), + ); + w.insert( + "hidden_norm.weight".to_string(), + ffi::zeros(&[4], dtype::FLOAT32), + ); w.insert("norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32)); // Layer 0 projections. q out = n_heads*head_dim = 4; k/v out = // n_kv_heads*head_dim = 2; o in = 4. @@ -593,9 +599,9 @@ mod tests { // Stand-in for the target's embedding table — a regular // [vocab, hidden] tensor, the same shape Qwen 3.5 hands out. - let target_embed = crate::layers::UnifiedEmbedding::Regular( - crate::layers::Embedding::new(ffi::zeros(&[8, 4], dtype::FLOAT32)), - ); + let target_embed = crate::layers::UnifiedEmbedding::Regular(crate::layers::Embedding::new( + ffi::zeros(&[8, 4], dtype::FLOAT32), + )); model.bind_target_embedding(target_embed); assert!( diff --git a/src/lib/mlxcel-core/src/drafter/dflash/round_loop.rs b/src/lib/mlxcel-core/src/drafter/dflash/round_loop.rs index 1c3a6761..4ebc1d33 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/round_loop.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/round_loop.rs @@ -262,10 +262,7 @@ pub trait SpeculativeTarget { /// `hidden = hidden[:, :max_a + 1, :]`); on full-accept rounds the /// loop forwards the full block. The target trait stays simple and /// the slice logic lives in one place. - fn concat_hidden_for_drafter( - &self, - verify_out: &Self::VerifyOut, - ) -> UniquePtr; + fn concat_hidden_for_drafter(&self, verify_out: &Self::VerifyOut) -> UniquePtr; /// Read the per-position logits out of `verify_out` for use by the /// round loop's argmax pass. Returned tensor has shape @@ -418,7 +415,9 @@ pub(crate) fn speculative_walk( // The target tensor has one more position than the drafter's proposals // (the trailing position is the post-block argmax, which the walk // takes as the bonus token only if the entire prefix matched). - let n = draft_tokens.len().min(target_tokens.len().saturating_sub(0)); + let n = draft_tokens + .len() + .min(target_tokens.len().saturating_sub(0)); let mut accepted = 0; while accepted < n { // Guard against target_tokens being shorter than expected (degenerate @@ -740,7 +739,8 @@ impl DFlashGenerator { // preparation overlap the host-side speculative walk. let hidden_phase_start = Instant::now(); let full_hidden = target.concat_hidden_for_drafter(&verify_out); - diagnostics.hidden_concat_time_ms += hidden_phase_start.elapsed().as_secs_f64() * 1000.0; + diagnostics.hidden_concat_time_ms += + hidden_phase_start.elapsed().as_secs_f64() * 1000.0; // SAFETY: both arrays are live for the duration of this call, and // the FFI bridge consumes the stack pointer slice synchronously to // schedule MLX evaluation. @@ -774,8 +774,7 @@ impl DFlashGenerator { // ---- Walk ---- let budget = max_tokens.saturating_sub(emitted); let phase_start = Instant::now(); - let (accepted, new_tokens) = - speculative_walk(&draft_tokens, &target_tokens, budget); + let (accepted, new_tokens) = speculative_walk(&draft_tokens, &target_tokens, budget); diagnostics.walk_time_ms += phase_start.elapsed().as_secs_f64() * 1000.0; self.accept_lens.push(accepted as u32); diagnostics.rounds += 1; @@ -806,9 +805,7 @@ impl DFlashGenerator { // `new_tokens.len() <= bs`, so `i` is always in // range — a missing entry degrades to `None` rather // than panicking. - let lp = target_logprobs - .as_ref() - .and_then(|v| v.get(i).cloned()); + let lp = target_logprobs.as_ref().and_then(|v| v.get(i).cloned()); logprobs.push(lp); } emitted += 1; @@ -971,8 +968,7 @@ fn per_position_logprobs( let shape = ffi::array_shape(logits); debug_assert!(shape.len() == 3, "expected [1, seq_len, vocab] logits"); let vocab = shape[2]; - let mut out: Vec = - Vec::with_capacity(target_tokens.len()); + let mut out: Vec = Vec::with_capacity(target_tokens.len()); for (pos, &tok) in target_tokens.iter().enumerate() { // Slice position `pos` to a `[1, vocab]` tensor — the shape // `compute_logprobs` expects. @@ -1305,10 +1301,7 @@ mod tests { self.rollback_events.set(ev); } - fn concat_hidden_for_drafter( - &self, - verify_out: &Self::VerifyOut, - ) -> UniquePtr { + fn concat_hidden_for_drafter(&self, verify_out: &Self::VerifyOut) -> UniquePtr { // Return the full captured hidden tensor; the round loop // does its own axis-1 slice on partial accept. ffi::slice( @@ -1456,16 +1449,11 @@ mod tests { propose_fn: impl FnMut(i32, usize) -> Vec + 'static, ) -> (DFlashRunOutput, Vec<(i32, i32)>, Vec) { let target = SyntheticTarget::new(vec![1, 8, 15, 22, 29], 5 * 8, argmax_fn); - let mut caches: Vec = (0..3) - .map(|_| SyntheticCache::default()) - .collect(); + let mut caches: Vec = (0..3).map(|_| SyntheticCache::default()).collect(); let drafter = SyntheticDrafter::new(propose_fn); let lm = EmbedOnlyLm; - let mut gen = DFlashGenerator::with_drafter( - Box::new(drafter), - SamplingConfig::greedy(), - ); + let mut gen = DFlashGenerator::with_drafter(Box::new(drafter), SamplingConfig::greedy()); // Round loop pulls block_size from the generator. gen.block_size = block_size; @@ -1505,17 +1493,15 @@ mod tests { #[test] fn round_loop_passes_drafter_target_layer_ids_to_verify() { let drafter_ids = vec![1, 16, 31, 46, 61]; - let target = SyntheticTarget::new( - vec![1, 8, 15, 22, 29], - 5 * 8, - |_s: i32, prev_token: i32| prev_token + 1, - ); + let target = + SyntheticTarget::new(vec![1, 8, 15, 22, 29], 5 * 8, |_s: i32, prev_token: i32| { + prev_token + 1 + }); let mut caches: Vec = (0..3).map(|_| SyntheticCache::default()).collect(); - let drafter = SyntheticDrafter::new(|bonus, bs| { - (1..bs as i32).map(|s| bonus + s).collect() - }) - .with_target_layer_ids(drafter_ids.clone()); + let drafter = + SyntheticDrafter::new(|bonus, bs| (1..bs as i32).map(|s| bonus + s).collect()) + .with_target_layer_ids(drafter_ids.clone()); let lm = EmbedOnlyLm; let mut gen = DFlashGenerator::new( Box::new(drafter), @@ -1561,16 +1547,11 @@ mod tests { #[test] fn round_loop_full_accept_every_round_skips_rollback() { let argmax_fn = |_s: i32, prev_token: i32| prev_token + 1; - let propose_fn = |bonus: i32, bs: usize| -> Vec { - (1..bs as i32).map(|s| bonus + s).collect() - }; + let propose_fn = + |bonus: i32, bs: usize| -> Vec { (1..bs as i32).map(|s| bonus + s).collect() }; let (out, rollback_events, verify_lens) = run_synthetic_round_loop( - 8, - /*max_tokens=*/ 24, - /*first_bonus=*/ 100, - argmax_fn, - propose_fn, + 8, /*max_tokens=*/ 24, /*first_bonus=*/ 100, argmax_fn, propose_fn, ); // Each round must have accepted exactly `block_size - 1 = 7`. @@ -1625,11 +1606,7 @@ mod tests { }; let (out, rollback_events, _) = run_synthetic_round_loop( - 8, - /*max_tokens=*/ 32, - /*first_bonus=*/ 100, - argmax_fn, - propose_fn, + 8, /*max_tokens=*/ 32, /*first_bonus=*/ 100, argmax_fn, propose_fn, ); for (i, acc) in out.accept_lens.iter().enumerate() { @@ -1678,11 +1655,7 @@ mod tests { }; let (out, rollback_events, _) = run_synthetic_round_loop( - 8, - /*max_tokens=*/ 24, - /*first_bonus=*/ 100, - argmax_fn, - propose_fn, + 8, /*max_tokens=*/ 24, /*first_bonus=*/ 100, argmax_fn, propose_fn, ); assert!( @@ -1697,10 +1670,7 @@ mod tests { // Rollback called on rounds 1 and 2 (partial accept) but NOT // on round 3 (full accept). Count only partial-accept events // (`accepted < block_size - 1`). - let partial_count = rollback_events - .iter() - .filter(|(a, b)| *a < *b - 1) - .count(); + let partial_count = rollback_events.iter().filter(|(a, b)| *a < *b - 1).count(); assert!( partial_count >= 2, "rollback should fire on partial-accept rounds 1 and 2; got {rollback_events:?}" @@ -1733,11 +1703,7 @@ mod tests { }; let (_out, rollback_events, _) = run_synthetic_round_loop( - 8, - /*max_tokens=*/ 24, - /*first_bonus=*/ 100, - argmax_fn, - propose_fn, + 8, /*max_tokens=*/ 24, /*first_bonus=*/ 100, argmax_fn, propose_fn, ); // Exactly one rollback event (round 2), not two or three. @@ -1779,7 +1745,7 @@ mod tests { let first_bonus = 100i32; let max_tokens = 33; // 1 first_bonus + 32 round-loop emissions - // Build the reference greedy sequence. + // Build the reference greedy sequence. let mut reference: Vec = Vec::with_capacity(max_tokens); reference.push(first_bonus); for _ in 1..max_tokens { @@ -1792,13 +1758,8 @@ mod tests { // can produce ([30, 230)). let propose_always_wrong = |_bonus: i32, bs: usize| -> Vec { (1..bs as i32).map(|_| 0).collect() }; - let (out, _, _) = run_synthetic_round_loop( - 8, - max_tokens, - first_bonus, - argmax_fn, - propose_always_wrong, - ); + let (out, _, _) = + run_synthetic_round_loop(8, max_tokens, first_bonus, argmax_fn, propose_always_wrong); let reference_tail = &reference[1..]; assert_eq!( @@ -1877,12 +1838,10 @@ mod tests { #[test] fn round_loop_stops_on_eos_emission() { let argmax_fn = |_s: i32, prev: i32| prev + 1; - let propose_fn = |bonus: i32, bs: usize| -> Vec { - (1..bs as i32).map(|s| bonus + s).collect() - }; + let propose_fn = + |bonus: i32, bs: usize| -> Vec { (1..bs as i32).map(|s| bonus + s).collect() }; let target = SyntheticTarget::new(vec![1, 8, 15, 22, 29], 5 * 8, argmax_fn); - let mut caches: Vec = - (0..3).map(|_| SyntheticCache::default()).collect(); + let mut caches: Vec = (0..3).map(|_| SyntheticCache::default()).collect(); let drafter = SyntheticDrafter::new(propose_fn); let lm = EmbedOnlyLm; let mut gen = DFlashGenerator::with_drafter(Box::new(drafter), SamplingConfig::greedy()); @@ -1914,13 +1873,16 @@ mod tests { #[test] fn round_loop_max_tokens_one_emits_nothing() { let argmax_fn = |_s: i32, _prev: i32| 99; - let propose_fn = |_bonus: i32, bs: usize| -> Vec { - (1..bs as i32).map(|_| 0).collect() - }; - let (out, rollback_events, verify_lens) = - run_synthetic_round_loop(8, /*max_tokens=*/ 1, /*first_bonus=*/ 100, argmax_fn, propose_fn); + let propose_fn = + |_bonus: i32, bs: usize| -> Vec { (1..bs as i32).map(|_| 0).collect() }; + let (out, rollback_events, verify_lens) = run_synthetic_round_loop( + 8, /*max_tokens=*/ 1, /*first_bonus=*/ 100, argmax_fn, propose_fn, + ); - assert!(out.tokens.is_empty(), "max_tokens=1 must emit no further tokens"); + assert!( + out.tokens.is_empty(), + "max_tokens=1 must emit no further tokens" + ); assert!(out.accept_lens.is_empty()); assert!(rollback_events.is_empty()); assert!(verify_lens.is_empty()); diff --git a/src/lib/mlxcel-core/src/drafter/dflash/round_loop_batched.rs b/src/lib/mlxcel-core/src/drafter/dflash/round_loop_batched.rs index bdb0b10c..f17a9e3b 100644 --- a/src/lib/mlxcel-core/src/drafter/dflash/round_loop_batched.rs +++ b/src/lib/mlxcel-core/src/drafter/dflash/round_loop_batched.rs @@ -165,11 +165,7 @@ pub(crate) fn speculative_walk_batched( /// Mirrors upstream `sampler(verify_out.logits)` with the greedy /// `sampler = argmax(axis=-1)` — stochastic batched DFlash sampling is /// a follow-up (#632). -fn argmax_logits_per_row( - logits: &MlxArray, - batch_size: i32, - seq_len: i32, -) -> Vec> { +fn argmax_logits_per_row(logits: &MlxArray, batch_size: i32, seq_len: i32) -> Vec> { let shape = ffi::array_shape(logits); debug_assert_eq!(shape.len(), 3, "expected [B, seq_len, vocab] logits"); debug_assert_eq!(shape[0], batch_size, "logits batch dim must match B"); @@ -448,10 +444,12 @@ impl DFlashBatchedGenerator { verify_buf.push(b[r]); verify_buf.extend_from_slice(&draft_tokens_per_row[r]); } - let verify_input = - ffi::from_slice_i32(&verify_buf, &[batch_size as i32, bs as i32]); - let verify_out = - target.verify_forward_with_capture_layers(&verify_input, caches, &capture_layer_ids); + let verify_input = ffi::from_slice_i32(&verify_buf, &[batch_size as i32, bs as i32]); + let verify_out = target.verify_forward_with_capture_layers( + &verify_input, + caches, + &capture_layer_ids, + ); // ---- Argmax sample (greedy at temp=0) of the per-row logits ---- let target_tokens_per_row = argmax_logits_per_row( @@ -464,11 +462,8 @@ impl DFlashBatchedGenerator { let budgets: Vec = (0..batch_size) .map(|r| max_new_tokens.saturating_sub(emitted[r])) .collect(); - let (accepted_per_row, new_tokens_per_row) = speculative_walk_batched( - &draft_tokens_per_row, - &target_tokens_per_row, - &budgets, - ); + let (accepted_per_row, new_tokens_per_row) = + speculative_walk_batched(&draft_tokens_per_row, &target_tokens_per_row, &budgets); // Record per-row accept lens. Finished rows still get an // entry (consistent shape across rounds) but the value is @@ -557,14 +552,8 @@ impl DFlashBatchedGenerator { // tail-zeroing handles rows whose accept counts are below // max_accepted. if min_accepted + 1 < bs { - let accepted_i32: Vec = - accepted_per_row.iter().map(|&a| a as i32).collect(); - target.rollback_partial_batched( - caches, - &verify_out, - &accepted_i32, - bs as i32, - ); + let accepted_i32: Vec = accepted_per_row.iter().map(|&a| a as i32).collect(); + target.rollback_partial_batched(caches, &verify_out, &accepted_i32, bs as i32); } // Periodic memory cache clear. Mirrors upstream @@ -822,8 +811,10 @@ mod tests { _accepted: i32, _block_size: i32, ) { - panic!("batched synthetic target should not call rollback_partial; \ - use rollback_partial_batched"); + panic!( + "batched synthetic target should not call rollback_partial; \ + use rollback_partial_batched" + ); } fn rollback_partial_batched( @@ -846,14 +837,15 @@ mod tests { self.rollback_events.set(ev); } - fn concat_hidden_for_drafter( - &self, - verify_out: &Self::VerifyOut, - ) -> UniquePtr { + fn concat_hidden_for_drafter(&self, verify_out: &Self::VerifyOut) -> UniquePtr { ffi::slice( &verify_out.captured_hidden, &[0, 0, 0], - &[verify_out.batch_size, verify_out.verify_len, self.concat_hidden_dim], + &[ + verify_out.batch_size, + verify_out.verify_len, + self.concat_hidden_dim, + ], ) } @@ -983,21 +975,15 @@ mod tests { ) -> (DFlashBatchedRunOutput, Vec<(Vec, i32)>) { let target = SyntheticTarget::new(5 * 8, argmax_fn); // 3 caches, one per "layer" (count is incidental for synthetic test). - let mut caches: Vec = - (0..3).map(|_| SyntheticCache::default()).collect(); + let mut caches: Vec = (0..3).map(|_| SyntheticCache::default()).collect(); let drafter = SyntheticBatchedDrafter::new(propose_fn); let lm = EmbedOnlyLm; - let mut gen = DFlashBatchedGenerator::with_drafter( - Box::new(drafter), - SamplingConfig::greedy(), - ); + let mut gen = + DFlashBatchedGenerator::with_drafter(Box::new(drafter), SamplingConfig::greedy()); gen.block_size = block_size; - let first_hidden = ffi::zeros( - &[batch_size as i32, 1, 5 * 8], - crate::dtype::FLOAT32, - ); + let first_hidden = ffi::zeros(&[batch_size as i32, 1, 5 * 8], crate::dtype::FLOAT32); let out = gen .run_batched( @@ -1210,7 +1196,10 @@ mod tests { // Drafter variant 1: always mismatches (accept 0 every round). let propose_always_wrong = |bonus: &[i32], bs: usize| -> Vec> { - bonus.iter().map(|_| (1..bs as i32).map(|_| 0).collect()).collect() + bonus + .iter() + .map(|_| (1..bs as i32).map(|_| 0).collect()) + .collect() }; let (out, _) = run_synthetic_batched_round_loop( 4, diff --git a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/layer.rs b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/layer.rs index fc701c17..3d0a8aa0 100644 --- a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/layer.rs +++ b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/layer.rs @@ -237,9 +237,14 @@ impl DrafterAttention { offsets, Some(freqs), ), - (None, RopeOffset::Scalar(offset)) => { - ffi::fast_rope(&queries, self.rope_dims, false, self.rope_theta, 1.0, offset) - } + (None, RopeOffset::Scalar(offset)) => ffi::fast_rope( + &queries, + self.rope_dims, + false, + self.rope_theta, + 1.0, + offset, + ), (None, RopeOffset::PerRow(offsets)) => crate::fast_rope_batched( &queries, self.rope_dims, diff --git a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/model.rs b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/model.rs index 5ec7272f..6d08541b 100644 --- a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/model.rs +++ b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/model.rs @@ -195,11 +195,7 @@ impl OwnedSharedKv { let kv_valid_len = kv_len.saturating_sub(left_padding as i32); let valid_scalar = BatchScalar::Scalar(kv_valid_len); let left_scalar = BatchScalar::Scalar(left_padding as i32); - Self::from_shared_kv_normalized_with_metadata( - shared, - &valid_scalar, - Some(&left_scalar), - ) + Self::from_shared_kv_normalized_with_metadata(shared, &valid_scalar, Some(&left_scalar)) } /// Batched-MTP constructor that accepts explicit per-row valid lengths @@ -240,7 +236,9 @@ impl OwnedSharedKv { let normalized = normalize_batched_shared_kv_states(&map, kv_valid_len, left_padding); let take_pair = |layer: LayerType| -> Option<(UniquePtr, UniquePtr)> { - normalized.get(&layer).map(|(k, v)| (ffi::copy(k), ffi::copy(v))) + normalized + .get(&layer) + .map(|(k, v)| (ffi::copy(k), ffi::copy(v))) }; Ok(Self { @@ -404,7 +402,10 @@ impl std::fmt::Debug for Gemma4AssistantDraftModel { .field("backbone_hidden_size", &self.config.backbone_hidden_size) .field("block_size", &self.config.block_size) .field("tie_word_embeddings", &self.config.tie_word_embeddings) - .field("use_ordered_embeddings", &self.config.use_ordered_embeddings) + .field( + "use_ordered_embeddings", + &self.config.use_ordered_embeddings, + ) .field("num_layers", &self.inner.layers.len()) .field("centroid_lm_head_ready", &self.centroid_lm_head.is_some()) .field("bound", &self.lm_head.is_some()) @@ -561,16 +562,16 @@ impl Gemma4AssistantDraftModel { // can't reach the weights here, the centroid head is pre-built // during `from_weights` (via `centroid_lm_head: Option`) // and this method just takes it out. - let centroid = self - .centroid_lm_head - .take() - .ok_or_else(|| DrafterError::WeightLoad { - reason: "use_ordered_embeddings=true but MaskedEmbedder was not pre-built \ + let centroid = + self.centroid_lm_head + .take() + .ok_or_else(|| DrafterError::WeightLoad { + reason: "use_ordered_embeddings=true but MaskedEmbedder was not pre-built \ during from_weights; ensure the checkpoint contains \ masked_embedding.centroids.weight and \ masked_embedding.token_ordering" - .into(), - })?; + .into(), + })?; LmHead::Centroid(centroid) } else if self.config.tie_word_embeddings { LmHead::Tied @@ -769,9 +770,7 @@ impl Gemma4AssistantDraftModel { // Convert the string layer_type to the `LayerType` enum to look // up in the `HashMap` returned by `make_drafter_masks`. let layer_type_enum = str_to_layer_type(layer.layer_type())?; - let mask_opt = masks - .get(&layer_type_enum) - .and_then(|m| m.as_deref()); + let mask_opt = masks.get(&layer_type_enum).and_then(|m| m.as_deref()); h = layer.forward(&h, mask_opt, k, v, rope_offset); } @@ -981,8 +980,9 @@ impl Drafter for Gemma4AssistantDraftModel { } // Per-row token-stream accumulators. - let mut tokens_per_row: Vec> = - (0..batch_size).map(|_| Vec::with_capacity(proposals as usize)).collect(); + let mut tokens_per_row: Vec> = (0..batch_size) + .map(|_| Vec::with_capacity(proposals as usize)) + .collect(); // Per-step recurrent state: `h_prev` starts at the caller's // [B, 1, backbone] target hidden; `last_tokens` starts at the diff --git a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/tests.rs b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/tests.rs index 51cb3f7d..006b535c 100644 --- a/src/lib/mlxcel-core/src/drafter/gemma4_assistant/tests.rs +++ b/src/lib/mlxcel-core/src/drafter/gemma4_assistant/tests.rs @@ -185,11 +185,7 @@ fn make_test_weights(config: &Gemma4AssistantConfig) -> WeightMap { &format!("{p}.mlp.gate_proj.weight"), &[inter, hidden], ); - insert_zeros( - &mut w, - &format!("{p}.mlp.up_proj.weight"), - &[inter, hidden], - ); + insert_zeros(&mut w, &format!("{p}.mlp.up_proj.weight"), &[inter, hidden]); insert_zeros( &mut w, &format!("{p}.mlp.down_proj.weight"), @@ -261,10 +257,7 @@ impl LanguageModel for MockLanguageModel { let shape = ffi::array_shape(input_ids); let b = shape[0]; let l = shape[1]; - Some(ffi::zeros( - &[b, l, self.hidden_size], - crate::dtype::FLOAT32, - )) + Some(ffi::zeros(&[b, l, self.hidden_size], crate::dtype::FLOAT32)) } } @@ -304,7 +297,11 @@ fn from_weights_loads_tied_dense_drafter() { let cfg = make_test_config(2, true); let weights = make_test_weights(&cfg); let model = Gemma4AssistantDraftModel::from_weights(weights, cfg); - assert!(model.is_ok(), "tied-dense drafter must load: {:?}", model.err()); + assert!( + model.is_ok(), + "tied-dense drafter must load: {:?}", + model.err() + ); } #[test] @@ -365,7 +362,9 @@ fn draft_block_rejects_call_before_set_shared_kv() { model.bind(&target).expect("bind"); let sampler = SamplingConfig::greedy(); - let err = model.draft_block(0, None, 4, &sampler).expect_err("must fail"); + let err = model + .draft_block(0, None, 4, &sampler) + .expect_err("must fail"); match err { DrafterError::SetSharedKvNotCalled => {} other => panic!("expected SetSharedKvNotCalled, got {other:?}"), @@ -381,7 +380,9 @@ fn draft_block_rejects_call_before_bind() { // Skip bind, set_shared_kv first (the order doesn't matter; both // pre-conditions must hold). draft_block must fail before it runs. let sampler = SamplingConfig::greedy(); - let err = model.draft_block(0, None, 4, &sampler).expect_err("must fail"); + let err = model + .draft_block(0, None, 4, &sampler) + .expect_err("must fail"); // Order check: set_shared_kv runs first inside draft_block, so the // first guard to trigger is SetSharedKvNotCalled. After that fixes, // BindNotCalled fires. The test pins the first-failure ordering. @@ -477,7 +478,11 @@ fn centroid_path_bind_set_shared_kv_draft_block_end_to_end() { let tc = cfg.text_config(); let vocab = tc.vocab_size; let backbone = cfg.backbone_hidden_size as i32; - assert_eq!(vocab % cfg.num_centroids, 0, "vocab must be divisible by num_centroids"); + assert_eq!( + vocab % cfg.num_centroids, + 0, + "vocab must be divisible by num_centroids" + ); // Rebuild pre/post projection weights for the corrected backbone size. let mut weights = make_test_weights(&cfg); @@ -489,7 +494,9 @@ fn centroid_path_bind_set_shared_kv_draft_block_end_to_end() { let mut model = Gemma4AssistantDraftModel::from_weights(weights, cfg.clone()).expect("load"); let target = MockLanguageModel::new(vocab as i32, 64); - model.bind(&target).expect("bind must succeed for centroid path"); + model + .bind(&target) + .expect("bind must succeed for centroid path"); // Set up shared K/V: 4 tensors (full + SWA) at [B=1, n_kv=1, kv=4, head=32]. let kv_shape = &[1_i32, 1, 4, 32]; @@ -505,7 +512,9 @@ fn centroid_path_bind_set_shared_kv_draft_block_end_to_end() { ]; let shared = crate::drafter::SharedKv::new(&tensors); model - .set_shared_kv(shared, /*kv_offset=*/ 0, /*position=*/ 0, /*left_padding=*/ 0) + .set_shared_kv( + shared, /*kv_offset=*/ 0, /*position=*/ 0, /*left_padding=*/ 0, + ) .expect("set_shared_kv"); // Build a hidden tensor [1, 1, backbone]. @@ -608,22 +617,10 @@ fn draft_block_batched_returns_b_rows_of_k_minus_one_proposals() { let kv_len = 8; let head_dim = 32; let n_kv = 1; - let k_full = ffi::zeros( - &[batch_size, n_kv, kv_len, head_dim], - crate::dtype::FLOAT32, - ); - let v_full = ffi::zeros( - &[batch_size, n_kv, kv_len, head_dim], - crate::dtype::FLOAT32, - ); - let k_swa = ffi::zeros( - &[batch_size, n_kv, kv_len, head_dim], - crate::dtype::FLOAT32, - ); - let v_swa = ffi::zeros( - &[batch_size, n_kv, kv_len, head_dim], - crate::dtype::FLOAT32, - ); + let k_full = ffi::zeros(&[batch_size, n_kv, kv_len, head_dim], crate::dtype::FLOAT32); + let v_full = ffi::zeros(&[batch_size, n_kv, kv_len, head_dim], crate::dtype::FLOAT32); + let k_swa = ffi::zeros(&[batch_size, n_kv, kv_len, head_dim], crate::dtype::FLOAT32); + let v_swa = ffi::zeros(&[batch_size, n_kv, kv_len, head_dim], crate::dtype::FLOAT32); let tensors: Vec<&MlxArray> = vec![ k_full.as_ref().unwrap(), v_full.as_ref().unwrap(), @@ -631,9 +628,7 @@ fn draft_block_batched_returns_b_rows_of_k_minus_one_proposals() { v_swa.as_ref().unwrap(), ]; let shared = SharedKv::new(&tensors); - model - .set_shared_kv(shared, 0, 0, 0) - .expect("set_shared_kv"); + model.set_shared_kv(shared, 0, 0, 0).expect("set_shared_kv"); // hidden tensor: [B, 1, backbone]. let hidden = ffi::zeros(&[batch_size, 1, backbone], crate::dtype::FLOAT32); @@ -668,9 +663,7 @@ fn draft_block_batched_rejects_block_size_zero_with_empty_rows() { let v = ffi::zeros(&[2, 1, 4, 32], crate::dtype::FLOAT32); let tensors: Vec<&MlxArray> = vec![k.as_ref().unwrap(), v.as_ref().unwrap()]; let shared = SharedKv::new(&tensors); - model - .set_shared_kv(shared, 0, 0, 0) - .expect("set_shared_kv"); + model.set_shared_kv(shared, 0, 0, 0).expect("set_shared_kv"); let last_bonus = vec![1_i32, 2]; let sampler = SamplingConfig::greedy(); diff --git a/src/lib/mlxcel-core/src/drafter/masked_embedder.rs b/src/lib/mlxcel-core/src/drafter/masked_embedder.rs index dd09ee78..f0d6e2a3 100644 --- a/src/lib/mlxcel-core/src/drafter/masked_embedder.rs +++ b/src/lib/mlxcel-core/src/drafter/masked_embedder.rs @@ -229,12 +229,11 @@ impl MaskedEmbedder { } })?; let ordering_key = format!("{prefix}.token_ordering"); - let token_ordering = weights - .get(&ordering_key) - .map(|w| ffi::copy(w)) - .ok_or(MaskedEmbedderError::MissingWeight { + let token_ordering = weights.get(&ordering_key).map(|w| ffi::copy(w)).ok_or( + MaskedEmbedderError::MissingWeight { key: ordering_key.clone(), - })?; + }, + )?; Self::new( centroids, token_ordering, @@ -276,10 +275,7 @@ impl MaskedEmbedder { // Normalise to [B, L, H] for the rest of the pipeline so the // reshape arithmetic stays uniform. let hidden = if shape.len() == 2 { - ffi::reshape( - hidden_states, - &[batch, seq_len, self.hidden_size as i32], - ) + ffi::reshape(hidden_states, &[batch, seq_len, self.hidden_size as i32]) } else { ffi::copy(hidden_states) }; @@ -355,11 +351,7 @@ impl MaskedEmbedder { // full_f32 casts the f32 scalar to the requested dtype, so bf16/f16 // inputs produce a bf16/f16 scratch tensor and the scatter that // follows stays in the model's native precision (no f32 promotion). - let out = ffi::full_f32( - &[batch, seq_len, self.vocab_size as i32], - mask_value, - dtype, - ); + let out = ffi::full_f32(&[batch, seq_len, self.vocab_size as i32], mask_value, dtype); // Step 10-11: scatter selected logits into the masked tensor. // @@ -418,10 +410,7 @@ mod tests { // Row c = [10*c + 1, 0.0]: this puts centroid c's logit at // 10*c*h[0] for an input with h[1] = 0, so centroids are perfectly // ordered by their index when h[0] > 0. - let cw = ffi::from_slice_f32( - &[1.0, 0.0, 11.0, 0.0, 21.0, 0.0, 31.0, 0.0], - &[4, 2], - ); + let cw = ffi::from_slice_f32(&[1.0, 0.0, 11.0, 0.0, 21.0, 0.0, 31.0, 0.0], &[4, 2]); let centroids = Linear::new(cw, None); // token_ordering[c*4 + k] = c*4 + k → contiguous block per centroid. @@ -779,10 +768,7 @@ mod tests { // Dense-LM-head drafters (26B-A4B, 31B) carry no centroid table — // the hook must not insert anything in their absence. let mut weights: WeightMap = WeightMap::new(); - weights.insert( - "something.else".to_string(), - from_slice_i32(&[0, 1], &[2]), - ); + weights.insert("something.else".to_string(), from_slice_i32(&[0, 1], &[2])); sanitize_token_ordering(&mut weights, "masked_embedding"); diff --git a/src/lib/mlxcel-core/src/drafter/masks.rs b/src/lib/mlxcel-core/src/drafter/masks.rs index a8fddf03..5df26525 100644 --- a/src/lib/mlxcel-core/src/drafter/masks.rs +++ b/src/lib/mlxcel-core/src/drafter/masks.rs @@ -171,13 +171,7 @@ pub fn bidirectional_full_mask( dtype: i32, ) -> Option> { let key_offset = BatchScalar::Scalar(0); - bidirectional_full_mask_with_key_offset( - _query_len, - kv_len, - kv_valid_len, - &key_offset, - dtype, - ) + bidirectional_full_mask_with_key_offset(_query_len, kv_len, kv_valid_len, &key_offset, dtype) } /// Build the bidirectional full-attention bias with an absolute K/V @@ -328,10 +322,7 @@ pub fn bidirectional_swa_mask_with_key_offset( if let (BatchScalar::Scalar(qo), true, BatchScalar::Scalar(ko)) = (query_offset, kv_valid_is_scalar, key_offset) { - if kv_len <= window - && *qo - *ko < window - && *ko + kv_len - (*qo + query_len) < window - { + if kv_len <= window && *qo - *ko < window && *ko + kv_len - (*qo + query_len) < window { return None; } } @@ -382,8 +373,8 @@ pub fn bidirectional_swa_mask_with_key_offset( let q_range_1d = ffi::arange_i32(0, query_len, 1); // [query_len] let q_range = ffi::reshape(&q_range_1d, &[1, query_len]); // [1, query_len] let q_idx_2d = ffi::add(&qo_col, &q_range); // [B, query_len] - // Reshape for the [B, query_len, kv_len] computation: - // q_idx -> [B, query_len, 1], k_idx -> [1, 1, kv_len]. + // Reshape for the [B, query_len, kv_len] computation: + // q_idx -> [B, query_len, 1], k_idx -> [1, 1, kv_len]. let q_idx = ffi::reshape(&q_idx_2d, &[batch, query_len, 1]); let k_idx = match key_offset { @@ -404,7 +395,7 @@ pub fn bidirectional_swa_mask_with_key_offset( // Per-row kv_valid_len tail-mask. let inside = apply_kv_valid_tail_batched(inside, &k_idx, kv_valid_len, batch); let bias_3d = build_bias_from_bool(&inside, dtype); // [B, query_len, kv_len] - // Reshape to [B, 1, query_len, kv_len]. + // Reshape to [B, 1, query_len, kv_len]. Some(ffi::reshape(&bias_3d, &[batch, 1, query_len, kv_len])) } } @@ -573,12 +564,9 @@ pub fn make_drafter_masks_with_valid_len( &effective_valid, dtype, ), - LayerType::FullAttention => make_full_mask_with_absolute_key_offset( - query_len, - kv_len, - &effective_valid, - dtype, - ), + LayerType::FullAttention => { + make_full_mask_with_absolute_key_offset(query_len, kv_len, &effective_valid, dtype) + } }; masks.insert(layer_type, mask); } @@ -855,11 +843,7 @@ fn roll_left_per_row( /// Mirror upstream's `_broadcast_batch_vector`: lift `value` to an /// `int32` 1-D tensor of length `batch`, repeating B=1 inputs across rows, /// and clip into `[0, limit]`. -fn broadcast_batch_vector( - value: &BatchScalar<'_>, - batch: i32, - limit: i32, -) -> UniquePtr { +fn broadcast_batch_vector(value: &BatchScalar<'_>, batch: i32, limit: i32) -> UniquePtr { let raw = match value { BatchScalar::Scalar(v) => ffi::from_slice_i32(&[*v], &[1]), BatchScalar::PerRow(arr) => { @@ -966,7 +950,12 @@ mod tests { #[test] fn full_mask_returns_none_when_no_padding() { // kv_valid_len == None: degenerate case, mask is None. - let mask = bidirectional_full_mask(/*query_len=*/ 1, /*kv_len=*/ 8, None, dtype::FLOAT32); + let mask = bidirectional_full_mask( + /*query_len=*/ 1, + /*kv_len=*/ 8, + None, + dtype::FLOAT32, + ); assert!(mask.is_none(), "full mask must be None without padding"); } @@ -986,8 +975,13 @@ mod tests { fn full_mask_scalar_short_prefix_produces_expected_bias() { // kv_len=5, kv_valid_len=3: positions [0,1,2]=0, [3,4]=-inf. let valid = BatchScalar::Scalar(3); - let mask = bidirectional_full_mask(/*query_len=*/ 1, /*kv_len=*/ 5, Some(&valid), dtype::FLOAT32) - .expect("mask must materialise when valid < kv_len"); + let mask = bidirectional_full_mask( + /*query_len=*/ 1, + /*kv_len=*/ 5, + Some(&valid), + dtype::FLOAT32, + ) + .expect("mask must materialise when valid < kv_len"); assert_eq!(ffi::array_shape(&mask), vec![1, 1, 1, 5]); // Inspect each column. @@ -1043,9 +1037,14 @@ mod tests { // the valid prefix. let valid = BatchScalar::Scalar(10); let key_offset = BatchScalar::Scalar(7); - let mask = - bidirectional_full_mask_with_key_offset(1, 4, Some(&valid), &key_offset, dtype::FLOAT32) - .expect("absolute key 10 must be masked"); + let mask = bidirectional_full_mask_with_key_offset( + 1, + 4, + Some(&valid), + &key_offset, + dtype::FLOAT32, + ) + .expect("absolute key 10 must be masked"); assert_eq!(ffi::array_shape(&mask), vec![1, 1, 1, 4]); for k in 0..3 { @@ -1074,7 +1073,10 @@ mod tests { None, dtype::FLOAT32, ); - assert!(mask.is_none(), "SWA mask should be None when window dominates"); + assert!( + mask.is_none(), + "SWA mask should be None when window dominates" + ); } #[test] @@ -1113,8 +1115,7 @@ mod tests { // query_offset=3, query_len=1, kv_len=8, window=2. // q = 3, |3 - k| < 2 ⇒ k ∈ {2, 3, 4} -> 0; else -inf. let qo = BatchScalar::Scalar(3); - let mask = - bidirectional_swa_mask(1, &qo, 8, 2, None, dtype::FLOAT32).expect("materialise"); + let mask = bidirectional_swa_mask(1, &qo, 8, 2, None, dtype::FLOAT32).expect("materialise"); for k in 0..2 { let v = mask_at_qk(&mask, &[0, 0, 0, k]); assert!(v.is_infinite() && v < 0.0, "k={k} must be -inf"); @@ -1155,8 +1156,7 @@ mod tests { shared.insert(LayerType::FullAttention, (&k_full, &v_full)); shared.insert(LayerType::SlidingWindowAttention, (&k_swa, &v_swa)); - let masks = - make_drafter_masks(&shared, query_len, &qo, sliding_window, dtype::FLOAT32); + let masks = make_drafter_masks(&shared, query_len, &qo, sliding_window, dtype::FLOAT32); assert_eq!(masks.len(), 2); assert!( @@ -1164,7 +1164,10 @@ mod tests { "full mask must be None in the fast path", ); assert!( - masks.get(&LayerType::SlidingWindowAttention).unwrap().is_none(), + masks + .get(&LayerType::SlidingWindowAttention) + .unwrap() + .is_none(), "SWA mask must be None in the fast path", ); } @@ -1210,8 +1213,8 @@ mod tests { #[test] fn full_mask_dtype_matches_request_bfloat16() { let valid = BatchScalar::Scalar(3); - let mask = bidirectional_full_mask(1, 5, Some(&valid), dtype::BFLOAT16) - .expect("materialise"); + let mask = + bidirectional_full_mask(1, 5, Some(&valid), dtype::BFLOAT16).expect("materialise"); assert_eq!( ffi::array_dtype(&mask), dtype::BFLOAT16, @@ -1234,8 +1237,8 @@ mod tests { #[test] fn swa_mask_dtype_matches_request_bfloat16() { let qo = BatchScalar::Scalar(0); - let mask = bidirectional_swa_mask(1, &qo, 8, 4, None, dtype::BFLOAT16) - .expect("materialise"); + let mask = + bidirectional_swa_mask(1, &qo, 8, 4, None, dtype::BFLOAT16).expect("materialise"); assert_eq!( ffi::array_dtype(&mask), dtype::BFLOAT16, diff --git a/src/lib/mlxcel-core/src/drafter/mod.rs b/src/lib/mlxcel-core/src/drafter/mod.rs index 8a135312..a5743079 100644 --- a/src/lib/mlxcel-core/src/drafter/mod.rs +++ b/src/lib/mlxcel-core/src/drafter/mod.rs @@ -88,15 +88,15 @@ use std::fs; use std::path::Path; use std::sync::OnceLock; +pub mod dflash; +/// Concrete Gemma 4 MTP "assistant" drafter implementation. Wired into +/// [`load_drafter`]'s `Mtp` arm in issue #626. +pub mod gemma4_assistant; /// Centroid-routed sparse softmax LM head used by Gemma 4 E2B / E4B /// assistant drafters. Wired into `Gemma4AssistantDraftModel` in sub-3 /// (#626) — landed here independently per issue #627 so the layer can /// be unit-tested in isolation before integration. pub mod masked_embedder; -pub mod dflash; -/// Concrete Gemma 4 MTP "assistant" drafter implementation. Wired into -/// [`load_drafter`]'s `Mtp` arm in issue #626. -pub mod gemma4_assistant; /// Drafter shapes recognised by mlxcel. /// @@ -1026,9 +1026,8 @@ mod tests { // `Dflash` arm cannot silently regress to `NotYetImplemented`. let dir = tempdir().unwrap(); write_drafter_config(&dir, None); - let err = load_drafter(dir.path(), Some(DrafterKind::Dflash)).expect_err( - "load_drafter must fail on a config-only fixture with no safetensors", - ); + let err = load_drafter(dir.path(), Some(DrafterKind::Dflash)) + .expect_err("load_drafter must fail on a config-only fixture with no safetensors"); match err { DrafterError::LoadFailed { reason } => { // Reason is implementation-defined; the typed variant diff --git a/src/lib/mlxcel-core/src/rope_proportional.rs b/src/lib/mlxcel-core/src/rope_proportional.rs index f7124e9a..b2ff48eb 100644 --- a/src/lib/mlxcel-core/src/rope_proportional.rs +++ b/src/lib/mlxcel-core/src/rope_proportional.rs @@ -232,24 +232,12 @@ pub fn apply_proportional_rope_batched( return copy(x); } if batch == 1 { - return apply_proportional_rope( - x, - head_dim, - partial_rotary_factor, - offsets[0], - freqs, - ); + return apply_proportional_rope(x, head_dim, partial_rotary_factor, offsets[0], freqs); } let first_offset = offsets[0]; if offsets[1..].iter().all(|&offset| offset == first_offset) { - return apply_proportional_rope( - x, - head_dim, - partial_rotary_factor, - first_offset, - freqs, - ); + return apply_proportional_rope(x, head_dim, partial_rotary_factor, first_offset, freqs); } let rank = shape.len(); @@ -265,13 +253,7 @@ pub fn apply_proportional_rope_batched( begin[0] = batch_idx as i32; end[0] = batch_idx as i32 + 1; let chunk = slice(x, &begin, &end); - let chunk = apply_proportional_rope( - &chunk, - head_dim, - partial_rotary_factor, - offset, - freqs, - ); + let chunk = apply_proportional_rope(&chunk, head_dim, partial_rotary_factor, offset, freqs); result = concatenate(&result, &chunk, 0); } diff --git a/src/lib/mlxcel-core/src/speculative/mod.rs b/src/lib/mlxcel-core/src/speculative/mod.rs index 6e3cb1fb..2886840a 100644 --- a/src/lib/mlxcel-core/src/speculative/mod.rs +++ b/src/lib/mlxcel-core/src/speculative/mod.rs @@ -272,8 +272,7 @@ impl SpeculativeGenerator { let chunk_input = ffi::from_slice_i32(chunk, &[1, step as i32]); // Prefill both models with the chunk (logits discarded — we only // need the KV cache state for the continuation). - let _main_chunk_logits = - main_model.forward(&chunk_input, &mut self.main_caches, None); + let _main_chunk_logits = main_model.forward(&chunk_input, &mut self.main_caches, None); let _draft_chunk_logits = draft_model.forward(&chunk_input, &mut self.draft_caches, None); // Evaluate only the KV cache state so it is materialised before @@ -815,8 +814,7 @@ mod tests { #[test] fn prefill_step_size_matches_upstream_default() { assert_eq!( - PREFILL_STEP_SIZE, - 512, + PREFILL_STEP_SIZE, 512, "PREFILL_STEP_SIZE must match upstream mlx-lm default (512). \ Update this test if you intentionally deviate." ); diff --git a/src/lib/mlxcel-core/src/speculative/mtp/generator.rs b/src/lib/mlxcel-core/src/speculative/mtp/generator.rs index 35c82352..383725cd 100644 --- a/src/lib/mlxcel-core/src/speculative/mtp/generator.rs +++ b/src/lib/mlxcel-core/src/speculative/mtp/generator.rs @@ -302,7 +302,12 @@ impl MtpGenerator { return ( emitted, logprobs, - Self::build_stats(prompt_len, 0, std::time::Duration::ZERO, std::time::Duration::ZERO), + Self::build_stats( + prompt_len, + 0, + std::time::Duration::ZERO, + std::time::Duration::ZERO, + ), ); } @@ -331,7 +336,12 @@ impl MtpGenerator { return ( emitted, logprobs, - Self::build_stats(prompt_len, gen_count, prefill_time, std::time::Duration::ZERO), + Self::build_stats( + prompt_len, + gen_count, + prefill_time, + std::time::Duration::ZERO, + ), ); } @@ -417,9 +427,9 @@ impl MtpGenerator { // for the walk; the captured state holds hidden + shared // K/V slabs for the finalize step. let verify_forward_start = Instant::now(); - let forward_out = - self.target - .verify_forward(&verify_input, sampling, logprobs_config); + let forward_out = self + .target + .verify_forward(&verify_input, sampling, logprobs_config); diagnostics.verify_forward_ms += duration_ms(verify_forward_start.elapsed()); // Walk the draft against the target's argmax tokens. @@ -436,11 +446,9 @@ impl MtpGenerator { // *before* `forward_out` is moved into `verify_finalize`. let target_logprobs = forward_out.target_logprobs; let verify_finalize_start = Instant::now(); - verify_out = self.target.verify_finalize( - walk.accepted, - actual_bs, - forward_out.captured, - ); + verify_out = + self.target + .verify_finalize(walk.accepted, actual_bs, forward_out.captured); diagnostics.verify_finalize_ms += duration_ms(verify_finalize_start.elapsed()); // Emit accepted tokens. `walk.new_tokens[i] == target_tokens[i]` @@ -456,9 +464,7 @@ impl MtpGenerator { // `walk.new_tokens.len() <= actual_bs`, so `i` is // always in range — but a missing entry degrades to // `None` rather than panicking. - let lp = target_logprobs - .as_ref() - .and_then(|v| v.get(i).cloned()); + let lp = target_logprobs.as_ref().and_then(|v| v.get(i).cloned()); logprobs.push(lp); } if eos_tokens.contains(&tok) { diff --git a/src/lib/mlxcel-core/src/speculative/mtp/round_loop_batched.rs b/src/lib/mlxcel-core/src/speculative/mtp/round_loop_batched.rs index 30ca25de..c94475f6 100644 --- a/src/lib/mlxcel-core/src/speculative/mtp/round_loop_batched.rs +++ b/src/lib/mlxcel-core/src/speculative/mtp/round_loop_batched.rs @@ -225,21 +225,19 @@ impl MtpBatchedGenerator { for (r, prompt) in prompt_tokens_per_row.iter().enumerate() { if prompt.is_empty() { return Err(DrafterError::DraftFailed { - reason: format!( - "MTP batched round loop: prompt row {r} must be non-empty" - ), + reason: format!("MTP batched round loop: prompt row {r} must be non-empty"), }); } } - let eos_tokens = - merged_eos_token_ids(self.target.eos_token_ids(), &sampler.stop_token_ids); + let eos_tokens = merged_eos_token_ids(self.target.eos_token_ids(), &sampler.stop_token_ids); // Per-row output streams. The first slot holds the seed bonus the // target produced from `prefill_and_seed_batched`; the round loop // appends to these. - let mut tokens_per_row: Vec> = - (0..batch_size).map(|_| Vec::with_capacity(max_new_tokens)).collect(); + let mut tokens_per_row: Vec> = (0..batch_size) + .map(|_| Vec::with_capacity(max_new_tokens)) + .collect(); let mut accept_lens_per_row: Vec> = (0..batch_size).map(|_| Vec::new()).collect(); let mut finished: Vec = vec![false; batch_size]; @@ -451,11 +449,9 @@ impl MtpBatchedGenerator { // This is the per-row rollback contract from #655 (Gemma 4 // `rollback_speculative_cache`) and #657 (per-row mask // normalization). - verify_out = self.target.verify_finalize_batched( - &accepted_per_row, - bs, - forward_out.captured, - )?; + verify_out = + self.target + .verify_finalize_batched(&accepted_per_row, bs, forward_out.captured)?; if finished.iter().all(|&f| f) { break; @@ -847,12 +843,7 @@ mod tests { vec![60, 61, 62, 63], ], ]; - let target = BatchedMockTarget::new( - vec![100, 200], - script.clone(), - vec![], - vec![0, 0], - ); + let target = BatchedMockTarget::new(vec![100, 200], script.clone(), vec![], vec![0, 0]); // Drafter proposes the next-round's first bs-1 target tokens // every round → full accept. Track per-row round indices. let captured_script: Vec>> = script.clone(); @@ -911,12 +902,7 @@ mod tests { ], ]; let captured_script = script.clone(); - let target = BatchedMockTarget::new( - vec![100, 200], - script, - vec![], - vec![0, 0], - ); + let target = BatchedMockTarget::new(vec![100, 200], script, vec![], vec![0, 0]); let round_indices: RefCell> = RefCell::new(vec![0; 2]); let drafter = BatchedMockDrafter::new(move |bonus: &[i32], bs: usize| -> Vec> { // Row 0: full match. Row 1: all 999 (mismatch every position). @@ -994,12 +980,8 @@ mod tests { ], ]; let captured_script = script.clone(); - let target = BatchedMockTarget::new( - vec![100, 200, 300, 400], - script, - vec![99], - vec![0, 0, 0, 0], - ); + let target = + BatchedMockTarget::new(vec![100, 200, 300, 400], script, vec![99], vec![0, 0, 0, 0]); let round_indices: RefCell> = RefCell::new(vec![0; 4]); let drafter = BatchedMockDrafter::new(move |bonus: &[i32], bs: usize| -> Vec> { // Each row proposes the first bs-1 of its current round's @@ -1125,13 +1107,14 @@ mod tests { Option, ) { let seed = self.build_verify_output(1); - let first_bonus_lp = logprobs_config.enabled.then(|| { - crate::sampling::TokenLogprobData { - token_id: self.first_bonus, - logprob: 0.0, - top_alternatives: Vec::new(), - } - }); + let first_bonus_lp = + logprobs_config + .enabled + .then(|| crate::sampling::TokenLogprobData { + token_id: self.first_bonus, + logprob: 0.0, + top_alternatives: Vec::new(), + }); (self.first_bonus, seed, first_bonus_lp) } fn embed_token(&self, _token_id: i32) -> UniquePtr { @@ -1230,8 +1213,7 @@ mod tests { // Generate the per-row reference token streams by running // MtpGenerator (B = 1) on each row. - let mut reference_tokens: Vec> = - Vec::with_capacity(first_bonus_per_row.len()); + let mut reference_tokens: Vec> = Vec::with_capacity(first_bonus_per_row.len()); let scripts = make_script(first_bonus_per_row.len()); for (r, &first_bonus) in first_bonus_per_row.iter().enumerate() { let row_script = scripts[r].clone(); @@ -1314,8 +1296,10 @@ mod tests { reference_tokens[r].len(), "row {r}: token count must match reference" ); - for (i, (got, want)) in - batched_row.iter().zip(reference_tokens[r].iter()).enumerate() + for (i, (got, want)) in batched_row + .iter() + .zip(reference_tokens[r].iter()) + .enumerate() { assert_eq!( got, want, @@ -1344,12 +1328,7 @@ mod tests { ], ]; let captured_script = script.clone(); - let target = BatchedMockTarget::new( - vec![100, 200], - script, - vec![], - vec![0, 0], - ); + let target = BatchedMockTarget::new(vec![100, 200], script, vec![], vec![0, 0]); let round_indices: RefCell> = RefCell::new(vec![0; 2]); let drafter = BatchedMockDrafter::new(move |bonus: &[i32], bs: usize| -> Vec> { let mut round_idx = round_indices.borrow_mut(); @@ -1403,12 +1382,8 @@ mod tests { ]; let captured_script = script.clone(); let left_padding = vec![4_usize, 0, 2]; - let target = BatchedMockTarget::new( - vec![100, 200, 300], - script, - vec![], - left_padding.clone(), - ); + let target = + BatchedMockTarget::new(vec![100, 200, 300], script, vec![], left_padding.clone()); let drafter = BatchedMockDrafter::new(move |bonus: &[i32], bs: usize| -> Vec> { let mut out = Vec::with_capacity(bonus.len()); for row in captured_script.iter().take(bonus.len()) { @@ -1419,11 +1394,7 @@ mod tests { let mut gen_ = MtpBatchedGenerator::new(target, Box::new(drafter), 4); let out = gen_ .run_batched( - &[ - vec![1], - vec![1, 2, 3, 4, 5], - vec![1, 2, 3], - ], + &[vec![1], vec![1, 2, 3, 4, 5], vec![1, 2, 3]], &SamplingConfig::greedy(), 5, ) @@ -1455,17 +1426,9 @@ mod tests { // Two rows, K = 4, just enough rounds to surface the per-round // rebind: 1 seed bind + 1 mid-loop rebind after the first // verify_finalize_batched. - let script = vec![ - vec![vec![10, 11, 12, 13]], - vec![vec![20, 21, 22, 23]], - ]; + let script = vec![vec![vec![10, 11, 12, 13]], vec![vec![20, 21, 22, 23]]]; let captured_script = script.clone(); - let target = BatchedMockTarget::new( - vec![100, 200], - script, - vec![], - vec![0, 0], - ); + let target = BatchedMockTarget::new(vec![100, 200], script, vec![], vec![0, 0]); let drafter = BatchedMockDrafter::new(move |bonus: &[i32], bs: usize| -> Vec> { let mut out = Vec::with_capacity(bonus.len()); for row in captured_script.iter().take(bonus.len()) { @@ -1499,31 +1462,19 @@ mod tests { let target = BatchedMockTarget::new(vec![], vec![], vec![], vec![]); let drafter = BatchedMockDrafter::new(|_b: &[i32], _bs: usize| Vec::new()); let mut gen_ = MtpBatchedGenerator::new(target, Box::new(drafter), 4); - let result = gen_.run_batched( - &Vec::>::new(), - &SamplingConfig::greedy(), - 10, - ); + let result = gen_.run_batched(&Vec::>::new(), &SamplingConfig::greedy(), 10); assert!(result.is_err(), "empty batch must be rejected"); } /// max_new_tokens = 1 short-circuits to seed bonus only. #[test] fn batched_max_tokens_one_emits_only_seed_bonus() { - let target = BatchedMockTarget::new( - vec![100, 200], - vec![vec![], vec![]], - vec![], - vec![0, 0], - ); + let target = + BatchedMockTarget::new(vec![100, 200], vec![vec![], vec![]], vec![], vec![0, 0]); let drafter = BatchedMockDrafter::new(|_b: &[i32], _bs: usize| Vec::new()); let mut gen_ = MtpBatchedGenerator::new(target, Box::new(drafter), 4); let out = gen_ - .run_batched( - &[vec![1], vec![2]], - &SamplingConfig::greedy(), - 1, - ) + .run_batched(&[vec![1], vec![2]], &SamplingConfig::greedy(), 1) .expect("max_tokens = 1 must succeed"); assert_eq!(out.tokens[0], vec![100]); assert_eq!(out.tokens[1], vec![200]); @@ -1533,10 +1484,7 @@ mod tests { /// other continues normally. #[test] fn batched_seed_bonus_eos_freezes_only_that_row() { - let script = vec![ - vec![vec![10, 11, 12, 13]], - vec![vec![20, 21, 22, 23]], - ]; + let script = vec![vec![vec![10, 11, 12, 13]], vec![vec![20, 21, 22, 23]]]; let captured_script = script.clone(); let target = BatchedMockTarget::new( vec![100, 200], @@ -1553,11 +1501,7 @@ mod tests { }); let mut gen_ = MtpBatchedGenerator::new(target, Box::new(drafter), 4); let out = gen_ - .run_batched( - &[vec![1], vec![2]], - &SamplingConfig::greedy(), - 5, - ) + .run_batched(&[vec![1], vec![2]], &SamplingConfig::greedy(), 5) .expect("batched MTP must complete"); assert_eq!(out.tokens[0], vec![100], "row 0 froze on EOS seed"); assert_eq!(out.tokens[1].len(), 5, "row 1 continued normally"); @@ -1566,24 +1510,12 @@ mod tests { /// Sanity: drafter that returns wrong row-count fails fast. #[test] fn batched_round_loop_rejects_bad_drafter_row_count() { - let script = vec![ - vec![vec![10, 11, 12, 13]], - vec![vec![20, 21, 22, 23]], - ]; - let target = BatchedMockTarget::new( - vec![100, 200], - script, - vec![], - vec![0, 0], - ); + let script = vec![vec![vec![10, 11, 12, 13]], vec![vec![20, 21, 22, 23]]]; + let target = BatchedMockTarget::new(vec![100, 200], script, vec![], vec![0, 0]); // Bad drafter: always returns 1 row regardless of input batch. let drafter = BatchedMockDrafter::new(|_b: &[i32], bs: usize| vec![vec![0; bs - 1]]); let mut gen_ = MtpBatchedGenerator::new(target, Box::new(drafter), 4); - let result = gen_.run_batched( - &[vec![1], vec![2]], - &SamplingConfig::greedy(), - 5, - ); + let result = gen_.run_batched(&[vec![1], vec![2]], &SamplingConfig::greedy(), 5); assert!(result.is_err(), "bad drafter row count must surface"); } diff --git a/src/lib/mlxcel-core/src/speculative/mtp/target.rs b/src/lib/mlxcel-core/src/speculative/mtp/target.rs index c00eb2d9..04de431c 100644 --- a/src/lib/mlxcel-core/src/speculative/mtp/target.rs +++ b/src/lib/mlxcel-core/src/speculative/mtp/target.rs @@ -117,10 +117,7 @@ impl std::fmt::Debug for VerifyForwardOutput { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("VerifyForwardOutput") .field("target_tokens", &self.target_tokens) - .field( - "has_target_logprobs", - &self.target_logprobs.is_some(), - ) + .field("has_target_logprobs", &self.target_logprobs.is_some()) .field("captured", &self.captured) .finish() } @@ -188,7 +185,10 @@ impl MtpVerifyOutput { pub fn shared_kv_refs(&self) -> Vec<&MlxArray> { self.next_shared_kv .iter() - .map(|ptr| ptr.as_ref().expect("MtpVerifyOutput: non-null shared_kv ptr")) + .map(|ptr| { + ptr.as_ref() + .expect("MtpVerifyOutput: non-null shared_kv ptr") + }) .collect() } } diff --git a/src/lib/mlxcel-core/src/speculative/mtp/tests.rs b/src/lib/mlxcel-core/src/speculative/mtp/tests.rs index 8649dccc..f000c420 100644 --- a/src/lib/mlxcel-core/src/speculative/mtp/tests.rs +++ b/src/lib/mlxcel-core/src/speculative/mtp/tests.rs @@ -30,8 +30,8 @@ use super::walk::speculative_walk; use crate::drafter::{Drafter, DrafterError, DrafterKind, SharedKv}; use crate::ffi::{self, MlxArray}; use crate::generate::SamplingConfig; -use crate::weights::WeightMap; use crate::sampling::LogprobsConfig; +use crate::weights::WeightMap; use cxx::UniquePtr; use std::cell::RefCell; use std::sync::atomic::AtomicBool; @@ -134,13 +134,13 @@ impl MtpTarget for MockMtpTarget { let seed = self.build_verify_output(1); // Synthetic first-bonus logprob: `None` when logprobs are // disabled (the existing tests' path), a dummy entry otherwise. - let first_bonus_lp = logprobs_config.enabled.then(|| { - crate::sampling::TokenLogprobData { + let first_bonus_lp = logprobs_config + .enabled + .then(|| crate::sampling::TokenLogprobData { token_id: self.first_bonus, logprob: 0.0, top_alternatives: Vec::new(), - } - }); + }); (self.first_bonus, seed, first_bonus_lp) } @@ -338,14 +338,17 @@ fn round_loop_full_accept_emits_all_proposals_plus_bonus_per_round() { ], vec![], // no EOS ); - let drafter = MockMtpDrafter::new(vec![ - vec![10, 11, 12], - vec![20, 21, 22], - vec![30, 31, 32], - ]); + let drafter = MockMtpDrafter::new(vec![vec![10, 11, 12], vec![20, 21, 22], vec![30, 31, 32]]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (tokens, _logprobs, stats) = gen_.generate(&[1, 2, 3], 13, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, stats) = gen_.generate( + &[1, 2, 3], + 13, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!( tokens, vec![100, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33], @@ -381,24 +384,22 @@ fn round_loop_partial_accept_rolls_back_by_block_size_minus_accepted_minus_one() // rollback: 4 - 0 - 1 = 3 // // Total emitted: [100, 10, 11, 20] - let target = MockMtpTarget::new( - vec![ - vec![10, 11, 12, 13], - vec![20, 21, 22, 23], - ], - vec![], - ); - let drafter = MockMtpDrafter::new(vec![ - vec![10, 99, 12], - vec![99, 21, 22], - ]); + let target = MockMtpTarget::new(vec![vec![10, 11, 12, 13], vec![20, 21, 22, 23]], vec![]); + let drafter = MockMtpDrafter::new(vec![vec![10, 99, 12], vec![99, 21, 22]]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); // max_tokens=4 caps emission at exactly the expected 4-token sequence // (seed + round-1's 2 accepted + round-2's 1 bonus). Larger // max_tokens would force more rounds against the reused scripted // entries, which is not what this test is measuring. - let (tokens, _logprobs, _stats) = gen_.generate(&[1, 2, 3], 4, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, _stats) = gen_.generate( + &[1, 2, 3], + 4, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!(tokens, vec![100, 10, 11, 20]); // Pin the rollback bookkeeping: `verify_finalize` must receive @@ -428,20 +429,18 @@ fn round_loop_respects_max_tokens_cap_on_full_accept() { // K=4, max_tokens=6. Seed bonus + 3 full-accept proposals would // emit 5 tokens; the second round adds 4 more but max_tokens=6 // caps at 6. Expected emission: [100, 10, 11, 12, 13, 20]. - let target = MockMtpTarget::new( - vec![ - vec![10, 11, 12, 13], - vec![20, 21, 22, 23], - ], - vec![], - ); - let drafter = MockMtpDrafter::new(vec![ - vec![10, 11, 12], - vec![20, 21, 22], - ]); + let target = MockMtpTarget::new(vec![vec![10, 11, 12, 13], vec![20, 21, 22, 23]], vec![]); + let drafter = MockMtpDrafter::new(vec![vec![10, 11, 12], vec![20, 21, 22]]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (tokens, _logprobs, stats) = gen_.generate(&[1, 2, 3], 6, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, stats) = gen_.generate( + &[1, 2, 3], + 6, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!(tokens.len(), 6, "must cap at max_tokens=6"); assert_eq!(tokens[0], 100, "seed bonus is first"); assert_eq!(stats.generated_tokens, 6); @@ -456,7 +455,14 @@ fn round_loop_stops_on_eos_token() { let drafter = MockMtpDrafter::new(vec![vec![10, 11, 12]]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (tokens, _logprobs, _stats) = gen_.generate(&[1], 20, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, _stats) = gen_.generate( + &[1], + 20, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!(tokens, vec![100, 10, 11, 12]); } @@ -468,7 +474,14 @@ fn round_loop_first_bonus_eos_short_circuits_seed() { let drafter = MockMtpDrafter::new(vec![]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (tokens, _logprobs, stats) = gen_.generate(&[1], 20, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, stats) = gen_.generate( + &[1], + 20, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!(tokens, vec![100]); assert_eq!(stats.generated_tokens, 1); } @@ -479,7 +492,14 @@ fn round_loop_max_tokens_one_emits_only_seed_bonus() { let drafter = MockMtpDrafter::new(vec![]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (tokens, _logprobs, _) = gen_.generate(&[1], 1, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (tokens, _logprobs, _) = gen_.generate( + &[1], + 1, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); assert_eq!(tokens, vec![100]); } @@ -496,20 +516,18 @@ fn round_loop_rebinds_drafter_after_each_round() { // After each verify, the round-loop must call `set_shared_kv` on // the drafter. This pins the rebind sequence: one seed call + one // per round. - let target = MockMtpTarget::new( - vec![ - vec![10, 11, 12, 13], - vec![20, 21, 22, 23], - ], - vec![], - ); - let drafter = MockMtpDrafter::new(vec![ - vec![10, 11, 12], - vec![20, 21, 22], - ]); + let target = MockMtpTarget::new(vec![vec![10, 11, 12, 13], vec![20, 21, 22, 23]], vec![]); + let drafter = MockMtpDrafter::new(vec![vec![10, 11, 12], vec![20, 21, 22]]); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (_tokens, _logprobs, _stats) = gen_.generate(&[1], 9, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (_tokens, _logprobs, _stats) = gen_.generate( + &[1], + 9, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); // The drafter is owned by the generator behind `Box`, // so we cannot downcast to read the log. Instead we exercise the @@ -561,7 +579,12 @@ fn greedy_parity_perfect_drafter_matches_no_drafter_baseline_32_tokens() { let scripted_target: Vec> = (0..11) .map(|r| { let base = 1000 + r * 4; - vec![base as i32, (base + 1) as i32, (base + 2) as i32, (base + 3) as i32] + vec![ + base as i32, + (base + 1) as i32, + (base + 2) as i32, + (base + 3) as i32, + ] }) .collect(); let scripted_draft: Vec> = scripted_target @@ -576,7 +599,14 @@ fn greedy_parity_perfect_drafter_matches_no_drafter_baseline_32_tokens() { let drafter = MockMtpDrafter::new(scripted_draft); let mut gen_ = MtpGenerator::new(target, Box::new(drafter), 4); - let (mtp_tokens, _logprobs, _) = gen_.generate(&[1], max_tokens, &SamplingConfig::greedy(), &[], &AtomicBool::new(false), &LogprobsConfig::default()); + let (mtp_tokens, _logprobs, _) = gen_.generate( + &[1], + max_tokens, + &SamplingConfig::greedy(), + &[], + &AtomicBool::new(false), + &LogprobsConfig::default(), + ); let baseline_tokens = greedy_baseline(&scripted_target, first_bonus, max_tokens); assert_eq!( diff --git a/src/lib/mlxcel-core/src/speculative/mtp/walk.rs b/src/lib/mlxcel-core/src/speculative/mtp/walk.rs index ef88da33..aced0a95 100644 --- a/src/lib/mlxcel-core/src/speculative/mtp/walk.rs +++ b/src/lib/mlxcel-core/src/speculative/mtp/walk.rs @@ -92,11 +92,7 @@ pub struct WalkResult { /// /// Both invariants are debug-asserted. Release builds tolerate /// shape mismatches by short-circuiting to a safe accept count. -pub fn speculative_walk( - draft_tokens: &[i32], - target_tokens: &[i32], - budget: usize, -) -> WalkResult { +pub fn speculative_walk(draft_tokens: &[i32], target_tokens: &[i32], budget: usize) -> WalkResult { debug_assert_eq!( target_tokens.len(), draft_tokens.len() + 1, @@ -337,10 +333,7 @@ mod tests { let budgets = vec![100, 100]; let (accepted, new_tokens) = speculative_walk_batched(&draft, &target, &budgets); assert_eq!(accepted, vec![3, 3]); - assert_eq!( - new_tokens, - vec![vec![10, 11, 12, 13], vec![20, 21, 22, 23]] - ); + assert_eq!(new_tokens, vec![vec![10, 11, 12, 13], vec![20, 21, 22, 23]]); } #[test] @@ -348,11 +341,7 @@ mod tests { // Acceptance criterion: per-row accept counts diverge in the // same batch. Row 0 full-accepts, row 1 mismatches at index 1, // row 2 mismatches at index 0. - let draft = vec![ - vec![10, 11, 12], - vec![20, 99, 22], - vec![99, 31, 32], - ]; + let draft = vec![vec![10, 11, 12], vec![20, 99, 22], vec![99, 31, 32]]; let target = vec![ vec![10, 11, 12, 13], vec![20, 21, 22, 23], diff --git a/src/lib/mlxcel-core/src/weights.rs b/src/lib/mlxcel-core/src/weights.rs index 5d4ff70b..69f4c37e 100644 --- a/src/lib/mlxcel-core/src/weights.rs +++ b/src/lib/mlxcel-core/src/weights.rs @@ -50,11 +50,7 @@ pub type WeightMap = HashMap>; pub trait WeightTransform { /// Apply the transform to `weights`. Returns `Ok(())` on success or /// an error string describing why the transform could not be applied. - fn apply( - &self, - weights: &mut WeightMap, - cfg: &serde_json::Value, - ) -> Result<(), String>; + fn apply(&self, weights: &mut WeightMap, cfg: &serde_json::Value) -> Result<(), String>; } /// Parse a `model.safetensors.index.json` file and return the set of unique shard filenames. diff --git a/src/lib/mlxcel-surgery/src/config.rs b/src/lib/mlxcel-surgery/src/config.rs index f8aae6e9..7a4a1de7 100644 --- a/src/lib/mlxcel-surgery/src/config.rs +++ b/src/lib/mlxcel-surgery/src/config.rs @@ -70,9 +70,9 @@ use globset::{Glob, GlobMatcher}; use serde::Deserialize; use crate::ops::{InterpolateOp, ScaleOp}; -use crate::{SharedSurgeryOp, SurgeryError, SurgeryPipeline}; #[cfg(test)] use crate::WeightMap; +use crate::{SharedSurgeryOp, SurgeryError, SurgeryPipeline}; /// The only schema version this parser understands. Bump when the /// YAML shape changes in a way that is not backwards-compatible; the @@ -332,11 +332,9 @@ fn materialize_op( if source_key.is_empty() { return Err(spec_error(idx, "replace", "source_key must not be empty")); } - let resolved = - resolve_existing_source(idx, "replace", "source", &source, base_dir)?; - let op = crate::ops::ReplaceOp::new(&pattern, &source_key, resolved).map_err( - |e| spec_error(idx, "replace", &format!("{e}")), - )?; + let resolved = resolve_existing_source(idx, "replace", "source", &source, base_dir)?; + let op = crate::ops::ReplaceOp::new(&pattern, &source_key, resolved) + .map_err(|e| spec_error(idx, "replace", &format!("{e}")))?; Ok(Arc::new(op)) } OpSpec::Interpolate { diff --git a/src/lib/mlxcel-surgery/src/ops/add.rs b/src/lib/mlxcel-surgery/src/ops/add.rs index 58876571..b29439fe 100644 --- a/src/lib/mlxcel-surgery/src/ops/add.rs +++ b/src/lib/mlxcel-surgery/src/ops/add.rs @@ -68,10 +68,10 @@ use std::path::{Path, PathBuf}; use globset::{Glob, GlobMatcher}; use mlxcel_core::dtype as mlx_dtype; -use mlxcel_core::weights::{WeightMap, load_safetensors}; +use mlxcel_core::weights::{load_safetensors, WeightMap}; use mlxcel_core::{ - MlxArray, UniquePtr, add as mlx_add, array_dtype, array_shape, astype, copy as mlx_copy, - multiply_scalar, + add as mlx_add, array_dtype, array_shape, astype, copy as mlx_copy, multiply_scalar, MlxArray, + UniquePtr, }; use crate::{SurgeryError, SurgeryOp}; @@ -180,11 +180,7 @@ impl AddOp { } impl SurgeryOp for AddOp { - fn apply( - &self, - weights: &mut WeightMap, - _cfg: &serde_json::Value, - ) -> Result<(), SurgeryError> { + fn apply(&self, weights: &mut WeightMap, _cfg: &serde_json::Value) -> Result<(), SurgeryError> { // Snapshot matched keys up front so we can mutate the map // afterwards without iterating a borrowed view. `keys` is // also handy for the "zero matches" diagnostic below. diff --git a/src/lib/mlxcel-surgery/src/ops/add_apply_tests.rs b/src/lib/mlxcel-surgery/src/ops/add_apply_tests.rs index b47559c4..fe8acaad 100644 --- a/src/lib/mlxcel-surgery/src/ops/add_apply_tests.rs +++ b/src/lib/mlxcel-surgery/src/ops/add_apply_tests.rs @@ -28,9 +28,7 @@ use mlxcel_core::weights::WeightMap; use mlxcel_core::{array_dtype, astype}; use super::add::AddOp; -use super::add_test_helpers::{ - OwnedTensor, extract_f32, f32_tensor, mlx_f32, write_single_donor, -}; +use super::add_test_helpers::{extract_f32, f32_tensor, mlx_f32, write_single_donor, OwnedTensor}; use crate::SurgeryOp; #[test] @@ -49,12 +47,8 @@ fn applies_correct_value_with_default_alpha() { mlx_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.mlp.down_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.down_proj.weight", &donor_path, 1.0).expect("construct"); op.apply(&mut weights, &serde_json::Value::Null) .expect("apply must succeed"); @@ -81,12 +75,8 @@ fn applies_correct_value_with_alpha_half() { mlx_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.mlp.down_proj.weight", - &donor_path, - 0.5, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.down_proj.weight", &donor_path, 0.5).expect("construct"); op.apply(&mut weights, &serde_json::Value::Null) .expect("apply must succeed"); @@ -116,12 +106,8 @@ fn applies_correct_value_with_negative_alpha() { mlx_f32(&[10.0, 10.0, 10.0, 10.0], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.mlp.down_proj.weight", - &donor_path, - -2.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.down_proj.weight", &donor_path, -2.0).expect("construct"); op.apply(&mut weights, &serde_json::Value::Null) .expect("apply must succeed"); @@ -174,8 +160,7 @@ fn alpha_zero_still_diagnoses_zero_match_patterns() { mlx_f32(&[1.0; 4], &[2, 2]), ); - let op = AddOp::new("layer.does.not.exist.*", "/anywhere.safetensors", 0.0) - .expect("construct"); + let op = AddOp::new("layer.does.not.exist.*", "/anywhere.safetensors", 0.0).expect("construct"); let err = op .apply(&mut weights, &serde_json::Value::Null) @@ -214,12 +199,8 @@ fn applies_to_multiple_matched_keys() { mlx_f32(&[100.0; 4], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.mlp.down_proj.weight", - &donor_path, - 2.5, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.down_proj.weight", &donor_path, 2.5).expect("construct"); op.apply(&mut weights, &serde_json::Value::Null) .expect("apply must succeed across layers"); @@ -253,17 +234,10 @@ fn donor_dtype_is_cast_to_base_dtype() { // Build a base f16 array via f32 → astype(f16). let base_f32 = mlx_f32(&[1.0, 1.0, 1.0, 1.0], &[2, 2]); let base_f16 = astype(&base_f32, mlx_dtype::FLOAT16); - weights.insert( - "model.layers.0.mlp.down_proj.weight".to_string(), - base_f16, - ); + weights.insert("model.layers.0.mlp.down_proj.weight".to_string(), base_f16); - let op = AddOp::new( - "model.layers.*.mlp.down_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.down_proj.weight", &donor_path, 1.0).expect("construct"); op.apply(&mut weights, &serde_json::Value::Null) .expect("dtype-cast donor must succeed"); diff --git a/src/lib/mlxcel-surgery/src/ops/add_test_helpers.rs b/src/lib/mlxcel-surgery/src/ops/add_test_helpers.rs index 7fd2eb1a..c93e3ce2 100644 --- a/src/lib/mlxcel-surgery/src/ops/add_test_helpers.rs +++ b/src/lib/mlxcel-surgery/src/ops/add_test_helpers.rs @@ -27,9 +27,9 @@ use std::path::Path; use mlxcel_core::dtype as mlx_dtype; -use mlxcel_core::{MlxArray, UniquePtr, array_to_raw_bytes, eval, from_bytes}; -use safetensors::View; +use mlxcel_core::{array_to_raw_bytes, eval, from_bytes, MlxArray, UniquePtr}; use safetensors::tensor::Dtype as SafeTensorDtype; +use safetensors::View; /// A `safetensors::View` impl over owned bytes — copied from the /// pattern used in `src/distributed/pipeline/partial_loading_adapter_tests.rs`. @@ -70,16 +70,11 @@ pub(crate) fn f32_tensor(values: &[f32], shape: &[usize]) -> OwnedTensor { /// Write a single-tensor safetensors file under `dir` named /// `donor.safetensors`. Returns the file path. -pub(crate) fn write_single_donor( - dir: &Path, - key: &str, - tensor: OwnedTensor, -) -> std::path::PathBuf { +pub(crate) fn write_single_donor(dir: &Path, key: &str, tensor: OwnedTensor) -> std::path::PathBuf { let mut tensors = std::collections::HashMap::new(); tensors.insert(key.to_string(), tensor); let path = dir.join("donor.safetensors"); - safetensors::serialize_to_file(&tensors, None, &path) - .expect("serialize donor safetensors"); + safetensors::serialize_to_file(&tensors, None, &path).expect("serialize donor safetensors"); path } diff --git a/src/lib/mlxcel-surgery/src/ops/add_tests.rs b/src/lib/mlxcel-surgery/src/ops/add_tests.rs index 3476bf76..1295ada5 100644 --- a/src/lib/mlxcel-surgery/src/ops/add_tests.rs +++ b/src/lib/mlxcel-surgery/src/ops/add_tests.rs @@ -51,11 +51,7 @@ fn new_rejects_bad_glob() { #[test] fn zero_match_pattern_errors() { let dir = tempfile::tempdir().expect("tempdir"); - let donor_path = write_single_donor( - dir.path(), - "unused.key", - f32_tensor(&[1.0, 1.0], &[2]), - ); + let donor_path = write_single_donor(dir.path(), "unused.key", f32_tensor(&[1.0, 1.0], &[2])); let mut weights = WeightMap::new(); weights.insert( @@ -116,12 +112,8 @@ fn missing_source_key_errors() { mlx_f32(&[1.0; 4], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.self_attn.q_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.self_attn.q_proj.weight", &donor_path, 1.0).expect("construct"); match op.apply(&mut weights, &serde_json::Value::Null) { Err(SurgeryError::TensorNotFound(key)) => { @@ -153,12 +145,8 @@ fn shape_mismatch_returns_structured_error() { mlx_f32(&[1.0; 4], &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.self_attn.q_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.self_attn.q_proj.weight", &donor_path, 1.0).expect("construct"); match op.apply(&mut weights, &serde_json::Value::Null) { Err(SurgeryError::ShapeMismatch { @@ -195,12 +183,8 @@ fn quantized_packed_base_returns_focused_error() { from_slice_u32(&packed_bits, &[2, 2]), ); - let op = AddOp::new( - "model.layers.*.mlp.gate_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct"); + let op = + AddOp::new("model.layers.*.mlp.gate_proj.weight", &donor_path, 1.0).expect("construct"); let err = op .apply(&mut weights, &serde_json::Value::Null) diff --git a/src/lib/mlxcel-surgery/src/ops/prune/tensor_ops.rs b/src/lib/mlxcel-surgery/src/ops/prune/tensor_ops.rs index 0af76ccd..612c8f85 100644 --- a/src/lib/mlxcel-surgery/src/ops/prune/tensor_ops.rs +++ b/src/lib/mlxcel-surgery/src/ops/prune/tensor_ops.rs @@ -41,10 +41,7 @@ use crate::{SurgeryError, WeightMap}; /// Replace `weights[key]` with a zero-filled tensor of the same shape /// and dtype. Caller has already verified the key exists. -pub(super) fn zero_tensor_inplace( - weights: &mut WeightMap, - key: &str, -) -> Result<(), SurgeryError> { +pub(super) fn zero_tensor_inplace(weights: &mut WeightMap, key: &str) -> Result<(), SurgeryError> { let arr = weights .get(key) .ok_or_else(|| SurgeryError::TensorNotFound(key.to_string()))?; diff --git a/src/lib/mlxcel-surgery/src/ops/prune/tests/attention_head.rs b/src/lib/mlxcel-surgery/src/ops/prune/tests/attention_head.rs index 6f429cfd..032860db 100644 --- a/src/lib/mlxcel-surgery/src/ops/prune/tests/attention_head.rs +++ b/src/lib/mlxcel-surgery/src/ops/prune/tests/attention_head.rs @@ -43,9 +43,7 @@ fn attention_head_prune_zeros_q_proj_row_slice() { let (shape, floats) = read_f32_2d(&weights, "model.layers.0.self_attn.q_proj.weight"); assert_eq!(shape, vec![32, 16]); - let row_sum = |r: usize| -> f32 { - floats[r * 16..(r + 1) * 16].iter().sum() - }; + let row_sum = |r: usize| -> f32 { floats[r * 16..(r + 1) * 16].iter().sum() }; // Head 0 rows [0..8): nonzero. for r in 0..8 { assert!(row_sum(r) > 0.0, "row {r} (head 0) should be nonzero"); @@ -202,7 +200,8 @@ fn attention_head_uses_text_config_for_vlm() { "num_hidden_layers": 1, } }); - op.apply(&mut weights, &cfg).expect("apply on VLM-style cfg"); + op.apply(&mut weights, &cfg) + .expect("apply on VLM-style cfg"); let (_, floats) = read_f32_2d( &weights, @@ -256,11 +255,7 @@ fn quantized_attention_head_zeros_weight_scales_biases() { for r in 8..16 { let start = r * row_bytes; for b in 0..row_bytes { - assert_eq!( - bytes[start + b], - 0, - "scales row {r} byte {b} must be zero" - ); + assert_eq!(bytes[start + b], 0, "scales row {r} byte {b} must be zero"); } } // Verify rows [0..8) are not zero (still ones). diff --git a/src/lib/mlxcel-surgery/src/ops/prune/tests/layer.rs b/src/lib/mlxcel-surgery/src/ops/prune/tests/layer.rs index 60c42e8c..ce19588e 100644 --- a/src/lib/mlxcel-surgery/src/ops/prune/tests/layer.rs +++ b/src/lib/mlxcel-surgery/src/ops/prune/tests/layer.rs @@ -50,11 +50,17 @@ fn layer_prune_zeros_only_requested_layer() { // Layer 0 q_proj is zero. let (_, floats) = read_f32_2d(&weights, "model.layers.0.self_attn.q_proj.weight"); - assert!(floats.iter().all(|&x| x == 0.0), "layer 0 q_proj must be zero"); + assert!( + floats.iter().all(|&x| x == 0.0), + "layer 0 q_proj must be zero" + ); // Layer 0 up_proj is zero. let (_, floats) = read_f32_2d(&weights, "model.layers.0.mlp.up_proj.weight"); - assert!(floats.iter().all(|&x| x == 0.0), "layer 0 up_proj must be zero"); + assert!( + floats.iter().all(|&x| x == 0.0), + "layer 0 up_proj must be zero" + ); // Layer 1 q_proj is unchanged (still ones). let (_, floats) = read_f32_2d(&weights, "model.layers.1.self_attn.q_proj.weight"); @@ -80,7 +86,9 @@ fn layer_prune_errors_on_out_of_range_id() { ); let op = PruneOp::new( "model.layers.*", - PruneSelector::Layer { layer_ids: vec![99] }, + PruneSelector::Layer { + layer_ids: vec![99], + }, ) .expect("compile"); let cfg = make_cfg(2, 2, 8, 16, 2); diff --git a/src/lib/mlxcel-surgery/src/ops/replace/mod.rs b/src/lib/mlxcel-surgery/src/ops/replace/mod.rs index bf2944b1..76c632f3 100644 --- a/src/lib/mlxcel-surgery/src/ops/replace/mod.rs +++ b/src/lib/mlxcel-surgery/src/ops/replace/mod.rs @@ -71,10 +71,10 @@ mod wildcard; use wildcard::WildcardPattern; -#[cfg(test)] -mod tests; #[cfg(test)] mod quant_tests; +#[cfg(test)] +mod tests; /// Concrete `SurgeryOp` implementing tensor substitution from an /// external donor safetensors file. @@ -160,11 +160,7 @@ impl ReplaceOp { } impl SurgeryOp for ReplaceOp { - fn apply( - &self, - weights: &mut WeightMap, - _cfg: &serde_json::Value, - ) -> Result<(), SurgeryError> { + fn apply(&self, weights: &mut WeightMap, _cfg: &serde_json::Value) -> Result<(), SurgeryError> { // 1) Collect base keys to replace. Sorting keeps the // operation deterministic across `HashMap` iteration // order so error messages are stable. @@ -234,15 +230,15 @@ impl SurgeryOp for ReplaceOp { // keeps memory bounded for multi-GB donor files when only // a handful of tensors are being replaced. let donor_keys_needed_capture = donor_keys_needed.clone(); - let donor_map: WeightMap = mlxcel_core::weights::load_safetensors_filtered( - donor_path_str, - move |name| donor_keys_needed_capture.contains(name), - ) - .map_err(|e| { - SurgeryError::Other(anyhow::anyhow!( - "replace: failed to load donor safetensors {donor_path_str}: {e}" - )) - })?; + let donor_map: WeightMap = + mlxcel_core::weights::load_safetensors_filtered(donor_path_str, move |name| { + donor_keys_needed_capture.contains(name) + }) + .map_err(|e| { + SurgeryError::Other(anyhow::anyhow!( + "replace: failed to load donor safetensors {donor_path_str}: {e}" + )) + })?; // 5) Verify every requested donor key was provided. We do // this before mutating the base map so that an error diff --git a/src/lib/mlxcel-surgery/src/ops/replace/quant_tests.rs b/src/lib/mlxcel-surgery/src/ops/replace/quant_tests.rs index d0ae1018..3932f0a6 100644 --- a/src/lib/mlxcel-surgery/src/ops/replace/quant_tests.rs +++ b/src/lib/mlxcel-surgery/src/ops/replace/quant_tests.rs @@ -207,12 +207,7 @@ fn non_quantized_base_with_no_siblings_does_not_require_donor_siblings() { ); // Note: NO `.scales` / `.biases` in the donor. let donor_path = write_quant_donor(dir.path(), donor_tensors); - let op = ReplaceOp::new( - "model.linear.weight", - "model.linear.weight", - donor_path, - ) - .unwrap(); + let op = ReplaceOp::new("model.linear.weight", "model.linear.weight", donor_path).unwrap(); let mut weights = WeightMap::new(); weights.insert( "model.linear.weight".to_string(), diff --git a/src/lib/mlxcel-surgery/src/ops/replace/tests.rs b/src/lib/mlxcel-surgery/src/ops/replace/tests.rs index 5dff23c9..c9973db6 100644 --- a/src/lib/mlxcel-surgery/src/ops/replace/tests.rs +++ b/src/lib/mlxcel-surgery/src/ops/replace/tests.rs @@ -251,15 +251,16 @@ fn apply_with_source_key_wildcard_substitution() { #[test] fn apply_shape_mismatch_errors_and_leaves_weights_untouched() { let dir = tempfile::tempdir().unwrap(); - let donor_path = write_donor_safetensors_f32( - dir.path(), - &[("model.x.weight", &[1.0, 2.0, 3.0], &[3])], - ); + let donor_path = + write_donor_safetensors_f32(dir.path(), &[("model.x.weight", &[1.0, 2.0, 3.0], &[3])]); let op = ReplaceOp::new("model.x.weight", "model.x.weight", donor_path).unwrap(); let mut weights = WeightMap::new(); // Base has shape [2], donor has shape [3] -> mismatch. - weights.insert("model.x.weight".to_string(), make_tensor(&[10.0, 20.0], &[2])); + weights.insert( + "model.x.weight".to_string(), + make_tensor(&[10.0, 20.0], &[2]), + ); let err = op .apply(&mut weights, &serde_json::Value::Null) @@ -319,8 +320,7 @@ fn apply_dtype_mismatch_errors() { #[test] fn apply_missing_source_key_returns_tensor_not_found() { let dir = tempfile::tempdir().unwrap(); - let donor_path = - write_donor_safetensors_f32(dir.path(), &[("some.other.key", &[1.0], &[1])]); + let donor_path = write_donor_safetensors_f32(dir.path(), &[("some.other.key", &[1.0], &[1])]); let op = ReplaceOp::new("model.x.weight", "model.x.weight", donor_path).unwrap(); let mut weights = WeightMap::new(); @@ -384,4 +384,3 @@ fn name_is_replace() { let op = ReplaceOp::new("x", "x", donor).unwrap(); assert_eq!(op.name(), "replace"); } - diff --git a/src/lib/mlxcel-surgery/src/ops/replace/wildcard.rs b/src/lib/mlxcel-surgery/src/ops/replace/wildcard.rs index 3497d076..cdf72b9a 100644 --- a/src/lib/mlxcel-surgery/src/ops/replace/wildcard.rs +++ b/src/lib/mlxcel-surgery/src/ops/replace/wildcard.rs @@ -222,9 +222,7 @@ mod tests { .match_with_captures("model.layers.0.self_attn.weight") .unwrap(); assert_eq!(caps, vec!["0".to_string()]); - assert!(p - .match_with_captures("model.layers.0.mlp.weight") - .is_none()); + assert!(p.match_with_captures("model.layers.0.mlp.weight").is_none()); } #[test] @@ -246,7 +244,9 @@ mod tests { #[test] fn trailing_star_captures_suffix() { let p = WildcardPattern::parse("model.layers.0.*"); - let caps = p.match_with_captures("model.layers.0.q_proj.weight").unwrap(); + let caps = p + .match_with_captures("model.layers.0.q_proj.weight") + .unwrap(); assert_eq!(caps, vec!["q_proj.weight".to_string()]); } diff --git a/src/lib/mlxcel-surgery/src/ops/scale.rs b/src/lib/mlxcel-surgery/src/ops/scale.rs index fe2395c7..10882911 100644 --- a/src/lib/mlxcel-surgery/src/ops/scale.rs +++ b/src/lib/mlxcel-surgery/src/ops/scale.rs @@ -140,11 +140,7 @@ impl SurgeryOp for ScaleOp { /// the configured scalar. /// /// See the module docstring for the quantized-layout routing rule. - fn apply( - &self, - weights: &mut WeightMap, - _cfg: &serde_json::Value, - ) -> Result<(), SurgeryError> { + fn apply(&self, weights: &mut WeightMap, _cfg: &serde_json::Value) -> Result<(), SurgeryError> { // Stage 1 — resolve every glob hit on the *as-loaded* key set // to one or more *effective* mutation targets. This isolates // the iteration from the subsequent mutation, lets us diff --git a/src/lib/mlxcel-surgery/src/ops/scale_tests.rs b/src/lib/mlxcel-surgery/src/ops/scale_tests.rs index 52e22769..60f8e560 100644 --- a/src/lib/mlxcel-surgery/src/ops/scale_tests.rs +++ b/src/lib/mlxcel-surgery/src/ops/scale_tests.rs @@ -194,8 +194,7 @@ fn scales_matched_tensors_and_leaves_others_alone() { #[test] fn glob_wildcard_matches_multiple_layers() { - let op = ScaleOp::new("model.layers.*.self_attn.o_proj.weight", 0.5) - .expect("construct"); + let op = ScaleOp::new("model.layers.*.self_attn.o_proj.weight", 0.5).expect("construct"); let mut weights = WeightMap::new(); for layer in 0..3 { weights.insert( @@ -209,7 +208,8 @@ fn glob_wildcard_matches_multiple_layers() { ); } - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); for layer in 0..3 { let o = weights @@ -232,7 +232,8 @@ fn preserves_dtype_and_shape_for_f16_tensor() { let mut weights = WeightMap::new(); weights.insert("model.layer.weight".to_string(), f16_arr); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); let after = weights.get("model.layer.weight").unwrap(); assert_eq!( @@ -270,12 +271,10 @@ fn quantized_affine_routes_to_scales_and_biases_not_packed_codes() { "model.layer.scales".to_string(), f32_tensor(&[0.25, 0.5, 1.0, 2.0]), ); - weights.insert( - "model.layer.biases".to_string(), - f32_tensor(&[10.0, -20.0]), - ); + weights.insert("model.layer.biases".to_string(), f32_tensor(&[10.0, -20.0])); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); // Packed codes — bit-identical. let packed = weights.get("model.layer.weight").unwrap(); @@ -315,7 +314,8 @@ fn quantized_mxfp4_layer_without_biases_scales_only_scales() { f32_tensor(&[2.0, -4.0, 8.0]), ); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); let packed = weights.get("model.embed.weight").unwrap(); mlxcel_core::eval(packed); @@ -343,7 +343,8 @@ fn pattern_matching_scales_directly_is_pass_through() { f32_tensor(&[0.1, 0.2, 0.3]), ); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); let after = weights.get("model.layer.scales").unwrap(); // Note: f32 multiply has rounding; assert with tolerance. @@ -374,7 +375,8 @@ fn wildcard_matching_quantized_triplet_does_not_double_scale() { weights.insert("model.layer.scales".to_string(), f32_tensor(&[3.0])); weights.insert("model.layer.biases".to_string(), f32_tensor(&[-5.0])); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); let scales = weights.get("model.layer.scales").unwrap(); assert_eq!( @@ -415,7 +417,8 @@ fn factor_one_is_a_value_preserving_pass() { f32_tensor(&[1.5, -2.5, 0.0, f32::MIN_POSITIVE]), ); - op.apply(&mut weights, &serde_json::Value::Null).expect("apply"); + op.apply(&mut weights, &serde_json::Value::Null) + .expect("apply"); let after = weights.get("model.layer.weight").unwrap(); let values = read_f32(after); diff --git a/src/lib/mlxcel-surgery/tests/add_integration.rs b/src/lib/mlxcel-surgery/tests/add_integration.rs index 346dcafd..da72af76 100644 --- a/src/lib/mlxcel-surgery/tests/add_integration.rs +++ b/src/lib/mlxcel-surgery/tests/add_integration.rs @@ -32,10 +32,10 @@ use std::path::Path; use mlxcel_core::dtype as mlx_dtype; use mlxcel_core::weights::{WeightMap, WeightTransform}; -use mlxcel_core::{MlxArray, UniquePtr, array_to_raw_bytes, eval, from_bytes}; -use mlxcel_surgery::{SurgeryPipeline, parse_config_file}; -use safetensors::View; +use mlxcel_core::{array_to_raw_bytes, eval, from_bytes, MlxArray, UniquePtr}; +use mlxcel_surgery::{parse_config_file, SurgeryPipeline}; use safetensors::tensor::Dtype as SafeTensorDtype; +use safetensors::View; /// `safetensors::View` over owned bytes — same shape as the helper /// in `src/lib/mlxcel-surgery/src/ops/add.rs`. Duplicated here @@ -123,10 +123,7 @@ fn yaml_driven_pipeline_applies_add_op_end_to_end() { for layer in 0..3 { donor.insert( format!("model.layers.{layer}.mlp.down_proj.weight"), - f32_tensor( - &[0.1 * (layer + 1) as f32; 8], - &[2, 4], - ), + f32_tensor(&[0.1 * (layer + 1) as f32; 8], &[2, 4]), ); } // Extra key that exists in the donor but should be irrelevant @@ -214,12 +211,8 @@ fn programmatic_pipeline_applies_add_op_end_to_end() { let mut pipeline = SurgeryPipeline::new(); pipeline.push(Arc::new( - mlxcel_surgery::AddOp::new( - "model.layers.*.self_attn.q_proj.weight", - &donor_path, - 1.0, - ) - .expect("construct AddOp"), + mlxcel_surgery::AddOp::new("model.layers.*.self_attn.q_proj.weight", &donor_path, 1.0) + .expect("construct AddOp"), )); let mut weights = WeightMap::new(); @@ -236,10 +229,7 @@ fn programmatic_pipeline_applies_add_op_end_to_end() { .get("model.layers.0.self_attn.q_proj.weight") .expect("present"); let actual = extract_f32(arr); - assert_eq!( - actual, - vec![11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0] - ); + assert_eq!(actual, vec![11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0]); } /// Acceptance criterion (e): without the op in the pipeline, an @@ -254,7 +244,10 @@ fn empty_pipeline_is_bit_exact_no_op() { let original: Vec = (0..16).map(|i| i as f32 * 0.5).collect(); let mut weights = WeightMap::new(); - weights.insert("model.embed_tokens.weight".to_string(), mlx_f32(&original, &[4, 4])); + weights.insert( + "model.embed_tokens.weight".to_string(), + mlx_f32(&original, &[4, 4]), + ); pipeline .apply(&mut weights, &serde_json::Value::Null) diff --git a/src/lib/mlxcel-surgery/tests/prune_integration.rs b/src/lib/mlxcel-surgery/tests/prune_integration.rs index f2648e47..4a17f4f1 100644 --- a/src/lib/mlxcel-surgery/tests/prune_integration.rs +++ b/src/lib/mlxcel-surgery/tests/prune_integration.rs @@ -27,7 +27,7 @@ use std::io::Write; use mlxcel_core::dtype; use mlxcel_core::weights::{WeightMap, WeightTransform}; use mlxcel_surgery::{ - PruneOp, PruneSelector, SurgeryPipeline, parse_config_file, parse_config_str, + parse_config_file, parse_config_str, PruneOp, PruneSelector, SurgeryPipeline, }; /// Build a Llama-style synthetic weight map for one layer with the @@ -83,7 +83,9 @@ operations: let pipeline = parse_config_str(yaml, None).expect("parse YAML"); assert_eq!(pipeline.len(), 1); - let mut weights = build_synthetic_weights(/* heads */ 4, /* head_dim */ 8, /* hidden */ 32); + let mut weights = build_synthetic_weights( + /* heads */ 4, /* head_dim */ 8, /* hidden */ 32, + ); let cfg = serde_json::json!({ "num_attention_heads": 4, "num_key_value_heads": 4, @@ -300,7 +302,10 @@ fn baseline_with_empty_pipeline_is_bit_exact_no_op() { mlxcel_core::eval(v); mlxcel_core::array_to_raw_bytes(v) }; - assert_eq!(before, &after, "key {k} must be byte-exact after empty pipeline"); + assert_eq!( + before, &after, + "key {k} must be byte-exact after empty pipeline" + ); } } diff --git a/src/lib/mlxcel-surgery/tests/replace_integration.rs b/src/lib/mlxcel-surgery/tests/replace_integration.rs index f1747312..a3111af1 100644 --- a/src/lib/mlxcel-surgery/tests/replace_integration.rs +++ b/src/lib/mlxcel-surgery/tests/replace_integration.rs @@ -29,7 +29,7 @@ //! `SurgeryPipeline` builder, `ReplaceOp` produces a `WeightMap` //! with the targeted tensor replaced by the donor's tensor. -use mlxcel_surgery::{WeightMap, WeightTransform, parse_config_file}; +use mlxcel_surgery::{parse_config_file, WeightMap, WeightTransform}; use std::collections::HashMap; use std::path::Path; @@ -150,10 +150,8 @@ operations: mlxcel_core::from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2]), ); - let err = - WeightTransform::apply(&pipeline, &mut weights, &serde_json::Value::Null).expect_err( - "zero-match pattern must error end-to-end through the pipeline", - ); + let err = WeightTransform::apply(&pipeline, &mut weights, &serde_json::Value::Null) + .expect_err("zero-match pattern must error end-to-end through the pipeline"); assert!( err.contains("matched no keys") || err.contains("replace"), "error must surface through SurgeryPipeline -> WeightTransform: {err}" diff --git a/src/lib/mlxcel-surgery/tests/scale_integration.rs b/src/lib/mlxcel-surgery/tests/scale_integration.rs index 32bd4534..1f4d3f62 100644 --- a/src/lib/mlxcel-surgery/tests/scale_integration.rs +++ b/src/lib/mlxcel-surgery/tests/scale_integration.rs @@ -85,15 +85,16 @@ operations: let values = read_f32(scaled); assert_eq!(values.len(), 4); for (got, want) in values.iter().zip([3.0_f32, 6.0, -9.0, 0.0].iter()) { - assert!( - (got - want).abs() < 1e-5, - "expected {want}, got {got}", - ); + assert!((got - want).abs() < 1e-5, "expected {want}, got {got}",); } // Untouched siblings — bit-identical. assert_eq!( - read_f32(weights.get("model.layers.0.self_attn.q_proj.weight").unwrap()), + read_f32( + weights + .get("model.layers.0.self_attn.q_proj.weight") + .unwrap() + ), vec![100.0, 200.0], ); assert_eq!( diff --git a/src/loading/vlm.rs b/src/loading/vlm.rs index 97d6fb4b..cb735c1d 100644 --- a/src/loading/vlm.rs +++ b/src/loading/vlm.rs @@ -181,11 +181,11 @@ pub(crate) fn load_vlm_weights_common( None => { #[cfg(feature = "surgery")] { - active_pipeline.as_deref().map( - |p: &crate::surgery::SurgeryPipeline| { + active_pipeline + .as_deref() + .map(|p: &crate::surgery::SurgeryPipeline| { p as &dyn mlxcel_core::weights::WeightTransform - }, - ) + }) } #[cfg(not(feature = "surgery"))] { diff --git a/src/models/sanitize.rs b/src/models/sanitize.rs index ff32ee52..7173bca9 100644 --- a/src/models/sanitize.rs +++ b/src/models/sanitize.rs @@ -1065,11 +1065,11 @@ pub fn load_text_weights>( None => { #[cfg(feature = "surgery")] { - active_pipeline.as_deref().map( - |p: &crate::surgery::SurgeryPipeline| { + active_pipeline + .as_deref() + .map(|p: &crate::surgery::SurgeryPipeline| { p as &dyn mlxcel_core::weights::WeightTransform - }, - ) + }) } #[cfg(not(feature = "surgery"))] { diff --git a/src/models/sanitize_tests.rs b/src/models/sanitize_tests.rs index ba9583ac..00c25b87 100644 --- a/src/models/sanitize_tests.rs +++ b/src/models/sanitize_tests.rs @@ -837,11 +837,10 @@ operations: "#, ) .unwrap(); - let pipeline = - mlxcel_surgery::parse_config_file(&yaml_path).expect("scale parses"); + let pipeline = mlxcel_surgery::parse_config_file(&yaml_path).expect("scale parses"); - let weights = load_text_weights(&dir, Some(&pipeline)) - .expect("scale through loader must succeed"); + let weights = + load_text_weights(&dir, Some(&pipeline)).expect("scale through loader must succeed"); let scaled = weights .get("model.layers.0.self_attn.q_proj.weight") @@ -856,8 +855,7 @@ operations: // None-transform path is the property under test here. let baseline_dir = temp_model_dir("yaml_scale_real_baseline"); write_text_model_fixture(&baseline_dir); - let baseline = load_text_weights(&baseline_dir, None) - .expect("baseline load must succeed"); + let baseline = load_text_weights(&baseline_dir, None).expect("baseline load must succeed"); let baseline_q = baseline .get("model.layers.0.self_attn.q_proj.weight") .unwrap(); @@ -892,8 +890,7 @@ operations: "#, ) .unwrap(); - let pipeline = - mlxcel_surgery::parse_config_file(&yaml_path).expect("scale parses"); + let pipeline = mlxcel_surgery::parse_config_file(&yaml_path).expect("scale parses"); let result = load_text_weights(&dir, Some(&pipeline)); match result { @@ -963,8 +960,7 @@ operations: pattern: "no.such.key.anywhere" factor: 2.0 "#; - let pipeline = mlxcel_surgery::parse_config_str(yaml, None) - .expect("scale yaml must parse"); + let pipeline = mlxcel_surgery::parse_config_str(yaml, None).expect("scale yaml must parse"); let _slot_guard = ScopedActivePipeline::install(Arc::new(pipeline)); let dir_with_slot = temp_model_dir("active_slot_installed"); @@ -1218,8 +1214,7 @@ operations: pattern: "model.layers.0.self_attn.*" head_ids: [2] "#; - let pipeline = - mlxcel_surgery::parse_config_str(yaml, None).expect("YAML must parse"); + let pipeline = mlxcel_surgery::parse_config_str(yaml, None).expect("YAML must parse"); assert_eq!(pipeline.len(), 1); let dir = temp_model_dir("prune_e2e_yaml"); @@ -1293,8 +1288,7 @@ operations: source_key: "model.embed_tokens.weight" "#; std::fs::write(&yaml_path, yaml).unwrap(); - let pipeline = - mlxcel_surgery::parse_config_file(&yaml_path).expect("replace yaml parses"); + let pipeline = mlxcel_surgery::parse_config_file(&yaml_path).expect("replace yaml parses"); let baseline = load_text_weights(&dir, None).unwrap(); let with_replace = load_text_weights(&dir, Some(&pipeline)).unwrap(); diff --git a/src/server/startup.rs b/src/server/startup.rs index f588f3a7..b1b8ca40 100644 --- a/src/server/startup.rs +++ b/src/server/startup.rs @@ -1225,10 +1225,7 @@ fn install_surgery_pipeline_for_server(startup: &ServerStartupConfig) -> Result< return Ok(()); }; if !path.exists() { - anyhow::bail!( - "--surgery: config file does not exist: {}", - path.display() - ); + anyhow::bail!("--surgery: config file does not exist: {}", path.display()); } let pipeline = crate::surgery::load_pipeline_from_file(path) .map_err(|e| anyhow::anyhow!("--surgery: {e}"))?; diff --git a/src/surgery.rs b/src/surgery.rs index 913e55f8..03708458 100644 --- a/src/surgery.rs +++ b/src/surgery.rs @@ -200,8 +200,8 @@ mod tests { #[test] fn load_pipeline_surfaces_io_error_for_missing_file() { let _guard = env_lock(); - let err = load_pipeline_from_file("/does/not/exist.yaml") - .expect_err("missing file must fail"); + let err = + load_pipeline_from_file("/does/not/exist.yaml").expect_err("missing file must fail"); assert!( err.contains("/does/not/exist.yaml"), "error must mention path: {err}" diff --git a/tests/surgery_cli.rs b/tests/surgery_cli.rs index 1f01fc93..a800ea38 100644 --- a/tests/surgery_cli.rs +++ b/tests/surgery_cli.rs @@ -206,9 +206,7 @@ fn extract_generated_suffix(stdout: &str) -> &str { // followed by a blank line and then the stats line. let start = stdout.find("Generating...\nHello").unwrap_or(0); let from_hello = &stdout[start..]; - let end = from_hello - .find("\n[Generated") - .unwrap_or(from_hello.len()); + let end = from_hello.find("\n[Generated").unwrap_or(from_hello.len()); &from_hello[..end] } @@ -243,11 +241,7 @@ fn malformed_surgery_yaml_fails_fast() { output.status ); let stderr = String::from_utf8_lossy(&output.stderr); - let combined = format!( - "{}{}", - String::from_utf8_lossy(&output.stdout), - stderr - ); + let combined = format!("{}{}", String::from_utf8_lossy(&output.stdout), stderr); assert!( combined.contains("surgery"), "error must mention 'surgery' so the user identifies the flag: {combined}"