Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lib/mlxcel-core/src/cache/paged_detach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ impl Drop for DetachedPagedCacheSet {

/// Internal parked-cache representation. Used by the pool so both dense and
/// paged detached sets share a single [`DetachedHandle`] space.
#[allow(clippy::large_enum_variant)]
pub(super) enum ParkedCache {
Dense(super::detach::DetachedCacheSet),
Paged(DetachedPagedCacheSet),
Expand Down Expand Up @@ -434,6 +435,7 @@ impl CachePool {
///
/// Retained block pins are preserved across failure paths so the caller
/// can still call [`CachePool::release_detached_paged`] manually or retry.
#[allow(clippy::result_large_err)]
pub fn adopt_paged_preserving(
&mut self,
model: &dyn crate::generate::LanguageModel,
Expand Down
4 changes: 2 additions & 2 deletions src/lib/mlxcel-core/src/cache/turbo/boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ mod tests {
// Cross-check the single-layer helper against the bulk helper.
let n = 16;
let bulk = resolve_layer_modes(KVCacheMode::Turbo4Asym, n, 2);
for i in 0..n {
for (i, &expected) in bulk.iter().enumerate().take(n) {
let single = resolve_layer_mode(KVCacheMode::Turbo4Asym, i, n, 2);
assert_eq!(bulk[i], single, "layer {i} mismatch");
assert_eq!(expected, single, "layer {i} mismatch");
}
}
}
2 changes: 1 addition & 1 deletion src/lib/mlxcel-core/src/cache/turbo/codebook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ mod tests {

// Φ(−1) ≈ 0.15865525393145702
let v = gaussian_cdf(-1.0);
assert!((v - 0.158_655_253_931_457_0).abs() < 1e-12, "Φ(-1) = {v}");
assert!((v - 0.158_655_253_931_457).abs() < 1e-12, "Φ(-1) = {v}");
}

#[test]
Expand Down
8 changes: 4 additions & 4 deletions src/lib/mlxcel-core/src/cache/turbo/pack3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ pub const V_BIT_WIDTH_3: u8 = 3;
#[inline]
pub fn packed_bytes_per_token_3bit(head_dim: i32) -> i32 {
debug_assert!(
head_dim > 0 && (head_dim as usize) % COORDS_PER_GROUP == 0,
head_dim > 0 && (head_dim as usize).is_multiple_of(COORDS_PER_GROUP),
"packed_bytes_per_token_3bit: head_dim must be a positive multiple of 8; \
got {head_dim}"
);
Expand All @@ -109,7 +109,7 @@ pub fn packed_bytes_per_token_3bit(head_dim: i32) -> i32 {
/// pipeline).
pub fn pack_3bit_indices(indices: &[u8], out: &mut [u8]) {
debug_assert!(
indices.len() % COORDS_PER_GROUP == 0,
indices.len().is_multiple_of(COORDS_PER_GROUP),
"pack_3bit_indices: index count ({}) must be a multiple of {}",
indices.len(),
COORDS_PER_GROUP
Expand Down Expand Up @@ -147,7 +147,7 @@ pub fn pack_3bit_indices(indices: &[u8], out: &mut [u8]) {
/// dequant path).
pub fn unpack_3bit_indices(packed: &[u8], out: &mut [u8]) {
debug_assert!(
packed.len() % BYTES_PER_GROUP == 0,
packed.len().is_multiple_of(BYTES_PER_GROUP),
"unpack_3bit_indices: packed byte count ({}) must be a multiple of {}",
packed.len(),
BYTES_PER_GROUP
Expand Down Expand Up @@ -311,7 +311,7 @@ mod tests {
#[test]
fn round_trip_random_multi_group() {
// 32 indices = 4 groups = 12 bytes packed. Use a deterministic LCG.
let mut state: u32 = 0xC0FFEE_42;
let mut state: u32 = 0xC0FF_EE42;
let mut indices = vec![0u8; 32];
for x in &mut indices {
state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
Expand Down
10 changes: 4 additions & 6 deletions src/lib/mlxcel-core/src/cache/turbo/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ fn quantize_into_packed(
let idx_off = tok * coords_per_token;
let mut sum_sq = 0.0_f32;
for d_i in 0..coords_per_token {
let c = centroids[indices[idx_off + d_i] as usize];
let c = centroids[indices[idx_off + d_i]];
sum_sq += c * c;
}
let y_hat_norm = sum_sq.sqrt();
Expand Down Expand Up @@ -874,8 +874,7 @@ mod tests {
let hd_usize = head_dim as usize;
let mut k_data: Vec<f32> = Vec::with_capacity(n_tokens * hd_usize);
let token_scales = [0.5_f32, 1.5, 4.0, 12.0];
for tok in 0..n_tokens {
let scale = token_scales[tok];
for &scale in token_scales.iter().take(n_tokens) {
for _ in 0..head_dim {
let mut acc = 0.0_f32;
for _ in 0..6 {
Expand Down Expand Up @@ -981,8 +980,7 @@ mod tests {
// norm storage, which is acceptable in production but would noise
// up this test.
let token_scales = [0.5_f32, 1.5, 4.0, 12.0];
for tok in 0..4 {
let scale = token_scales[tok];
for &scale in &token_scales {
for _ in 0..head_dim {
// Pseudo-Gaussian via summed uniform bits → ~normal in [-1, 1].
let mut acc = 0.0_f32;
Expand Down Expand Up @@ -1057,7 +1055,7 @@ mod tests {
let head_dim: i32 = 128;
let params = TurboQuantParams::new(head_dim as u32, 0xC1_DECAFu32);

let mut rng = Lcg32::new(0xC0FFEE_42);
let mut rng = Lcg32::new(0xC0FF_EE42);
let n_tokens = 4usize;
let hd_usize = head_dim as usize;
let mut v_data: Vec<f32> = Vec::with_capacity(n_tokens * hd_usize);
Expand Down
5 changes: 2 additions & 3 deletions src/lib/mlxcel-core/src/cache/turbo/quant3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl TurboQuantParams3 {
"TurboQuantParams3::new: head_dim must be a non-zero power of two; got {head_dim}"
);
assert!(
head_dim % 8 == 0,
head_dim.is_multiple_of(8),
"TurboQuantParams3::new: head_dim must be a multiple of 8 \
for the 24-bit packing layout; got {head_dim}"
);
Expand Down Expand Up @@ -373,8 +373,7 @@ mod tests {
let mut state: u32 = 123;
let mut v_data: Vec<f32> = Vec::with_capacity(4 * head_dim as usize);
let token_scales = [0.5_f32, 1.5, 4.0, 12.0];
for tok in 0..4 {
let scale = token_scales[tok];
for &scale in &token_scales {
for _ in 0..head_dim {
let mut acc = 0.0_f32;
for _ in 0..6 {
Expand Down
2 changes: 2 additions & 0 deletions src/lib/mlxcel-core/src/cache/turbo/sparse_v.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ pub fn compute_alive_mask(
/// Used by: tests under `cache/turbo_tests.rs` (issue #480 unit tests).
/// Production attention call sites continue to use the standard
/// `attention()` path until the Metal kernel lands.
#[allow(clippy::too_many_arguments)]
pub fn attention_sparse_v_turbo4(
q: &MlxArray,
k: &MlxArray,
Expand Down Expand Up @@ -526,6 +527,7 @@ pub fn attention_sparse_v_turbo4(
/// case.
///
/// Used by: [`KVCache::sparse_v_attention`] (issue #505).
#[allow(clippy::too_many_arguments)]
pub fn attention_sparse_v_turbo4_fused(
q: &MlxArray,
k: &MlxArray,
Expand Down
22 changes: 11 additions & 11 deletions src/lib/mlxcel-core/src/cache/turbo_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ fn turbo4_asym_clone_handle_install_then_dequant_matches_pre_detach() {
let v_data: Vec<f32> = (0..2 * head_dim)
.map(|i| (((i * 7) % 13) as f32 / 13.0) - 0.5)
.collect();
let k = ffi::from_slice_f32(&k_data, &[1, 1, 2, head_dim as i32]);
let v = ffi::from_slice_f32(&v_data, &[1, 1, 2, head_dim as i32]);
let k = ffi::from_slice_f32(&k_data, &[1, 1, 2, head_dim]);
let v = ffi::from_slice_f32(&v_data, &[1, 1, 2, head_dim]);

let (_k1, v1_out) = cache.update_and_fetch(k, v);
let v1 = flatten_fp32(&v1_out);
Expand Down Expand Up @@ -901,13 +901,13 @@ mod boundary_v {
fn typical_32_layer_split_protects_first_two_and_last_two() {
let n = 32;
let modes = resolve_layer_modes(KVCacheMode::Turbo4Asym, n, 2);
for i in 0..n {
for (i, &actual) in modes.iter().enumerate().take(n) {
let expected = if i < 2 || i >= n - 2 {
KVCacheMode::Fp16
} else {
KVCacheMode::Turbo4Asym
};
assert_eq!(modes[i], expected, "layer {i}");
assert_eq!(actual, expected, "layer {i}");
}
}

Expand All @@ -922,10 +922,10 @@ mod boundary_v {
KVCacheMode::Turbo4Delegated,
] {
let bulk = resolve_layer_modes(mode, n, 2);
for i in 0..n {
for (i, &expected) in bulk.iter().enumerate().take(n) {
let single = resolve_layer_mode(mode, i, n, 2);
assert_eq!(
bulk[i], single,
expected, single,
"{mode:?} layer {i}: bulk vs single helper disagree"
);
}
Expand Down Expand Up @@ -1414,8 +1414,8 @@ fn turbo3_asym_clone_handle_install_then_dequant_matches_pre_detach() {
let v_data: Vec<f32> = (0..2 * head_dim)
.map(|i| (((i * 7) % 13) as f32 / 13.0) - 0.5)
.collect();
let k = ffi::from_slice_f32(&k_data, &[1, 1, 2, head_dim as i32]);
let v = ffi::from_slice_f32(&v_data, &[1, 1, 2, head_dim as i32]);
let k = ffi::from_slice_f32(&k_data, &[1, 1, 2, head_dim]);
let v = ffi::from_slice_f32(&v_data, &[1, 1, 2, head_dim]);

let (_k1, v1_out) = cache.update_and_fetch(k, v);
let v1 = flatten_fp32(&v1_out);
Expand Down Expand Up @@ -1667,7 +1667,7 @@ fn turbo3_asym_layer_modes_apply_boundary_upgrade() {
assert_eq!(modes.len(), 8);
// First 2 + last 2 are upgraded to FP16; middle 4 stay Turbo3Asym.
for (i, m) in modes.iter().enumerate() {
if i < 2 || i >= 6 {
if !(2..6).contains(&i) {
assert_eq!(*m, KVCacheMode::Fp16, "layer {i} must be FP16 boundary");
} else {
assert_eq!(
Expand Down Expand Up @@ -2608,8 +2608,8 @@ fn delegated_dequant_sdpa_matches_reference_attention() {
/// Steel-envelope vs cold-only fused composition parity (issue #531).
///
/// `update_and_turbo4_delegated_attention` first tries the steel-envelope
/// kernel (issue #531, single Metal dispatch covering softmax + cold-V dequant
/// + hot-V accumulate). On the same hardware where the cold-only fused
/// kernel (issue #531, single Metal dispatch covering softmax, cold-V dequant,
/// and hot-V accumulate). On the same hardware where the cold-only fused
/// composition path (issue #528) already produces correct output, the
/// steel-envelope path must produce numerically equivalent output (within
/// FP16 round-off) — both paths funnel through the same MLX softmax algebra
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ mod tests {
///
/// `argmax_fn(row, position, prev_token) -> next_token`. Tests seed
/// this so the accept pattern is deterministic.
#[allow(clippy::type_complexity)]
struct SyntheticTarget {
argmax_fn: Box<SyntheticArgmaxFn>,
concat_hidden_dim: i32,
Expand Down
11 changes: 7 additions & 4 deletions src/lib/mlxcel-core/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ pub trait LanguageModel {
/// amortizes weight-loading bandwidth across the batch.
///
/// Used by: BatchScheduler (server continuous batching)
#[allow(clippy::needless_range_loop)]
fn forward_batched(
&self,
input_ids: &MlxArray,
Expand Down Expand Up @@ -822,7 +823,7 @@ impl CxxGenerator {
let n_layers = self.caches.len();
let requested = crate::cache::turbo::boundary_v_layers_from_env();
let layer_modes = crate::cache::turbo::resolve_layer_modes(nominal, n_layers, requested);
for (cache, mode) in self.caches.iter_mut().zip(layer_modes.into_iter()) {
for (cache, mode) in self.caches.iter_mut().zip(layer_modes) {
cache.mode = mode;
}
}
Expand Down Expand Up @@ -2096,7 +2097,7 @@ mod tests {
// n=0 is the very first decode iteration after prefill. Clearing here
// would discard tensors needed for the pipelined next-step computation.
let n = 0_usize;
assert!(!(n % 256 == 0 && n > 0));
assert!(!(n.is_multiple_of(256) && n > 0));
}

#[test]
Expand Down Expand Up @@ -2259,8 +2260,10 @@ mod tests {
let caller_bias = make_bias(&[(99, -3.0)]);
let g = CxxGenerator::new(2).with_token_bias(cached);

let mut caller = SamplingConfig::default();
caller.token_bias = caller_bias;
let caller = SamplingConfig {
token_bias: caller_bias,
..SamplingConfig::default()
};
let composed = g.compose_sampling(&caller);

// Caller's explicit bias is preserved verbatim, cached bias is ignored.
Expand Down
11 changes: 10 additions & 1 deletion src/lib/mlxcel-core/src/layers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ impl FusedQKVLinear {
/// Preserves both quantization `biases` and true linear `bias` tensors
/// when present; Qwen2-family checkpoints require q/k/v linear bias for
/// sane logits.
#[allow(clippy::too_many_arguments)]
pub fn from_weights_separate_with_mode(
weights: &crate::weights::WeightMap,
prefix: &str,
Expand Down Expand Up @@ -1773,6 +1774,7 @@ fn na_attention_log_mode() -> NaAttentionLogMode {
})
}

#[allow(clippy::too_many_arguments)]
fn classify_na_attention_dispatch(
q: &MlxArray,
k: &MlxArray,
Expand Down Expand Up @@ -1813,6 +1815,7 @@ fn classify_na_attention_dispatch(
(route, fast_path_eligible, nax_eligible)
}

#[allow(clippy::too_many_arguments)]
fn na_attention_eligibility_reasons(
q: &MlxArray,
k: &MlxArray,
Expand Down Expand Up @@ -1879,6 +1882,7 @@ fn na_attention_eligibility_reasons(
(fast_reason, nax_reason)
}

#[allow(clippy::too_many_arguments)]
fn record_na_attention_dispatch(
q: &MlxArray,
k: &MlxArray,
Expand All @@ -1894,7 +1898,7 @@ fn record_na_attention_dispatch(
let count = DISPATCH_COUNT.fetch_add(1, Ordering::Relaxed) + 1;
let should_log = match na_attention_log_mode() {
NaAttentionLogMode::Off => false,
NaAttentionLogMode::Sampled => count <= 8 || count % 100 == 0,
NaAttentionLogMode::Sampled => count <= 8 || count.is_multiple_of(100),
NaAttentionLogMode::All => true,
};
if should_log {
Expand Down Expand Up @@ -1993,6 +1997,11 @@ pub fn attention(
/// Pointer-friendly attention wrapper for existing model call sites.
///
/// Used by: Model attention call sites that still store masks as raw pointers
///
/// # Safety
///
/// `mask` must be either null or point to a valid, live `MlxArray`. The
/// referent must remain valid for the duration of this call.
pub unsafe fn attention_from_ptr(
q: &MlxArray,
k: &MlxArray,
Expand Down
7 changes: 3 additions & 4 deletions src/lib/mlxcel-core/src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,11 +977,10 @@ mod tests {
"top token should survive nucleus (got {})",
v[0]
);
for i in 1..5 {
for (i, val) in v.iter().enumerate().take(5).skip(1) {
assert!(
v[i].is_infinite() && v[i] < 0.0,
"token {i} should be filtered to -inf (got {})",
v[i]
val.is_infinite() && *val < 0.0,
"token {i} should be filtered to -inf (got {val})",
);
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/lib/mlxcel-core/src/speculative/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,10 @@ mod tests {
let caller_bias = make_bias(&[(42, f32::NEG_INFINITY)]);
let g = SpeculativeGenerator::new(2, 1).with_token_bias(cached);

let mut caller = SamplingConfig::default();
caller.token_bias = caller_bias;
let caller = SamplingConfig {
token_bias: caller_bias,
..SamplingConfig::default()
};
let target = g.compose_target_sampling(&caller);

assert_eq!(
Expand Down
10 changes: 5 additions & 5 deletions src/lib/mlxcel-core/src/speculative/mtp/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
//! arm its `set_shared_kv` with the seed slabs.
//! 3. **Round loop.** While more tokens remain:
//! a. Drafter produces `K-1` proposals via
//! [`crate::drafter::Drafter::draft_block`] (autoregressive, with
//! RoPE queries frozen at the bonus token's absolute position).
//! [`crate::drafter::Drafter::draft_block`] (autoregressive, with
//! RoPE queries frozen at the bonus token's absolute position).
//! b. Target verify on `[bonus, draft_0, …, draft_{K-2}]` via
//! [`MtpTarget::verify`] — produces `target_tokens`, the next
//! hidden, and the re-sliced shared K/V.
//! [`MtpTarget::verify`] — produces `target_tokens`, the next
//! hidden, and the re-sliced shared K/V.
//! c. Compare draft vs target via [`super::speculative_walk`].
//! d. Emit `new_tokens`. Update `bonus` to the last emitted token.
//! e. Rebind the drafter against the new shared K/V via
//! [`crate::drafter::Drafter::set_shared_kv`].
//! [`crate::drafter::Drafter::set_shared_kv`].
//! 4. **Termination.** Loop exits when an emitted token is in the
//! target's `eos_token_ids()` or `emitted >= max_tokens`.

Expand Down
7 changes: 1 addition & 6 deletions src/lib/mlxcel-core/src/speculative/mtp/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,12 +579,7 @@ fn greedy_parity_perfect_drafter_matches_no_drafter_baseline_32_tokens() {
let scripted_target: Vec<Vec<i32>> = (0..11)
.map(|r| {
let base = 1000 + r * 4;
vec![
base as i32,
(base + 1) as i32,
(base + 2) as i32,
(base + 3) as i32,
]
vec![base, (base + 1), (base + 2), (base + 3)]
})
.collect();
let scripted_draft: Vec<Vec<i32>> = scripted_target
Expand Down
Loading