Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ What did you run to convince yourself this works? Be specific.
For inference changes, real-checkpoint validation is required — synthetic-only is not enough.
-->

- [ ] `cargo fmt --check`
- [ ] `cargo fmt --all -- --check` (enforced by CI — violations block merge)
- [ ] `cargo clippy --all-targets -- -D warnings`
- [ ] `cargo test --release`
- [ ] `cargo deny check`
Expand Down
24 changes: 20 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# PR / push CI for security audit. Builds and tests are covered by
# PR / push CI for lightweight quality gates. Builds and tests are covered by
# release.yml (signed artifacts) and pipeline-parallel-ci.yml (distributed
# runs); this workflow runs the lightweight cargo-deny gate on every
# touched-Rust change so license drift and new advisories are caught at
# PR time.
# runs); this workflow runs cargo-deny (license / advisory) and cargo-fmt
# checks on every touched-Rust change so formatting drift and license issues
# are caught at PR time.

name: CI

Expand Down Expand Up @@ -54,3 +54,19 @@ jobs:
with:
command: check
log-level: warn

fmt:
name: cargo-fmt
needs: changes
if: needs.changes.outputs.rust == 'true'
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v6
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt --all -- --check
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ Thank you for your interest in contributing to mlxcel! This document covers the
```
4. Run the local quality gates:
```bash
cargo fmt --check
cargo fmt --all -- --check # enforced by CI; fmt violations block merge
cargo clippy --all-targets -- -D warnings
cargo deny check # advisories + licenses + sources
cargo deny check # advisories + licenses + sources
```
5. For inference changes, validate against a real checkpoint — synthetic or build-only validation is not enough (see [`AGENTS.md`](AGENTS.md) for why).
6. Commit with a conventional prefix (see below) and a clear message.
Expand Down
29 changes: 10 additions & 19 deletions src/lib/mlxcel-core/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2373,10 +2373,8 @@ impl KVCache {

let k_slice = ffi::slice(k_int8, &[0, 0, 0, 0], &[ks[0], ks[1], live_len, ks[3]]);
let v_slice = ffi::slice(v_int8, &[0, 0, 0, 0], &[vs[0], vs[1], live_len, vs[3]]);
let ks_slice =
ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
let vs_slice =
ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);
let ks_slice = ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
let vs_slice = ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);

(
dequantize(&k_slice, &ks_slice),
Expand Down Expand Up @@ -3785,8 +3783,7 @@ impl RotatingKVCache {
return (new_keys, new_values);
}

let (base_k, base_v, current_seq_len) =
self.visible_fp16_prefix_for_concat();
let (base_k, base_v, current_seq_len) = self.visible_fp16_prefix_for_concat();

let concat_k = concatenate(&base_k, &new_keys, 2);
let concat_v = concatenate(&base_v, &new_values, 2);
Expand Down Expand Up @@ -5681,10 +5678,7 @@ mod tests {
.collect::<Vec<_>>()
};
assert_eq!(to_f32(&visible_keys), vec![1.0, 5.0, 6.0, 7.0, 8.0]);
assert_eq!(
to_f32(&visible_values),
vec![10.0, 50.0, 60.0, 70.0, 80.0]
);
assert_eq!(to_f32(&visible_values), vec![10.0, 50.0, 60.0, 70.0, 80.0]);
}

#[test]
Expand Down Expand Up @@ -6190,10 +6184,8 @@ mod tests {
}
let (q_unrot, _) = unit_token(42);
let q_ref = rotate_at(&q_unrot, M);
let (k_ref, v_ref) = cache_ref.update_and_fetch(
rotate_at(&unit_token(99).0, M),
unit_token(99).1,
);
let (k_ref, v_ref) =
cache_ref.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
let out_ref = mlxcel_core::causal_attention(&q_ref, &k_ref, &v_ref, scale, 0.0, 0);
let out_ref_f32 = to_f32(&out_ref);

Expand All @@ -6216,11 +6208,10 @@ mod tests {
// position `M` to simulate the pre-fix offset decrement.
assert_eq!(cache_broken.trim_front(N), N);
let q_broken = rotate_at(&q_unrot, M);
let (k_broken, v_broken) = cache_broken.update_and_fetch(
rotate_at(&unit_token(99).0, M),
unit_token(99).1,
);
let out_broken = mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
let (k_broken, v_broken) =
cache_broken.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
let out_broken =
mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
let out_broken_f32 = to_f32(&out_broken);

let mut sq_err = 0.0_f64;
Expand Down
21 changes: 8 additions & 13 deletions src/lib/mlxcel-core/src/cache/detach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,14 @@ impl DetachedKVCache {

// Slice axis 2 (seq-len) of each tensor to the new length. Width axes
// (B, H, head_dim / packed_dim) come from the existing shape.
let trim_axis_seq = |a: &Option<UniquePtr<MlxArray>>,
tail_axis: i32|
-> Option<UniquePtr<MlxArray>> {
a.as_ref().map(|arr| {
let shape = ffi::array_shape(arr);
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
ffi::slice(
arr,
&[0, 0, 0, 0],
&[shape[0], shape[1], new_len, last],
)
})
};
let trim_axis_seq =
|a: &Option<UniquePtr<MlxArray>>, tail_axis: i32| -> Option<UniquePtr<MlxArray>> {
a.as_ref().map(|arr| {
let shape = ffi::array_shape(arr);
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
ffi::slice(arr, &[0, 0, 0, 0], &[shape[0], shape[1], new_len, last])
})
};

if self.mode == KVCacheMode::Turbo4Delegated {
// K is unified — same shape contract as Fp16. Slice to `new_len`.
Expand Down
18 changes: 8 additions & 10 deletions src/lib/mlxcel-core/src/cache/detach_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,7 @@ fn detached_kvcache_trim_to_mid_buffer_resets_offset_and_visible_keys() {
#[test]
fn detached_kvcache_trim_to_zero_drops_storage_and_keeps_mode() {
let mut live = KVCache::new_with_mode(KVCacheMode::Int8);
live.update(
fp32_tokens(&[1.0, 2.0, 3.0]),
fp32_tokens(&[4.0, 5.0, 6.0]),
);
live.update(fp32_tokens(&[1.0, 2.0, 3.0]), fp32_tokens(&[4.0, 5.0, 6.0]));
let mut handle = live.clone_handle();
assert_eq!(handle.seq_len(), 3);
assert_eq!(handle.mode(), KVCacheMode::Int8);
Expand All @@ -698,7 +695,9 @@ fn detached_kvcache_trim_to_exact_offset_is_noop() {
live.update(fp32_tokens(&[1.0, 2.0]), fp32_tokens(&[3.0, 4.0]));
let mut handle = live.clone_handle();

handle.trim_to(2).expect("trim_to exact offset must succeed");
handle
.trim_to(2)
.expect("trim_to exact offset must succeed");
assert_eq!(handle.seq_len(), 2);

let mut restored = KVCache::new();
Expand Down Expand Up @@ -825,7 +824,9 @@ fn detached_cache_set_truncate_to_zero_drops_every_layer() {
let v: Vec<f32> = (1..=4).map(|i| i as f32 * 10.0).collect();
let mut detached = detached_set_with_fp16_layers(2, &k, &v);

detached.truncate_to(0).expect("truncate_to zero must succeed");
detached
.truncate_to(0)
.expect("truncate_to zero must succeed");
assert!(detached.caches.iter().all(|c| c.is_empty()));
assert_eq!(detached.current_offset, 0);
assert_eq!(detached.prompt_len, 0);
Expand Down Expand Up @@ -940,10 +941,7 @@ fn detached_cache_set_truncate_to_int8_preserves_dequantization() {
let seq_adopted = pool.adopt(&model, detached).unwrap();
let (k_a, v_a) = {
let caches = pool.get_caches_mut(seq_adopted).unwrap();
caches[0].update_and_fetch(
fp32_tokens(&full_k[5..]),
fp32_tokens(&full_v[5..]),
)
caches[0].update_and_fetch(fp32_tokens(&full_k[5..]), fp32_tokens(&full_v[5..]))
};
eval(&k_a);
eval(&v_a);
Expand Down
1 change: 0 additions & 1 deletion src/lib/mlxcel-core/src/drafter/dflash/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,4 +441,3 @@ mod tests {
);
}
}

42 changes: 16 additions & 26 deletions src/lib/mlxcel-core/src/drafter/dflash/drafter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,11 @@ impl DFlashDrafter {
path: config_path.display().to_string(),
source: e,
})?;
let config = DFlashConfig::from_json(&config_json).map_err(|e| {
DrafterError::ConfigParse {
let config =
DFlashConfig::from_json(&config_json).map_err(|e| DrafterError::ConfigParse {
path: config_path.display().to_string(),
source: serde::de::Error::custom(e),
}
})?;
})?;

let mut weights = crate::weights::load_weights_from_dir(path)
.map_err(|msg| DrafterError::LoadFailed { reason: msg })?;
Expand Down Expand Up @@ -209,8 +208,9 @@ impl Drafter for DFlashDrafter {
// expose `embed_tokens` (e.g. a non-Qwen-3.5 model fed in by
// mistake).
if self.model.needs_embed_binding() {
let embed = target.embed_tokens_module().ok_or_else(|| {
DrafterError::BindFailed {
let embed = target
.embed_tokens_module()
.ok_or_else(|| DrafterError::BindFailed {
reason: format!(
"DFlash drafter checkpoint omits embed_tokens.weight \
and the target does not expose embed_tokens_module(); \
Expand All @@ -220,8 +220,7 @@ impl Drafter for DFlashDrafter {
(kind = {})",
self.kind()
),
}
})?;
})?;
self.model.bind_target_embedding(embed);
} else {
// Legacy capability smoke-test for self-contained checkpoints:
Expand Down Expand Up @@ -427,10 +426,7 @@ fn sample_block_per_position_batched(
sampler: &SamplingConfig,
) -> Result<Vec<Vec<i32>>, DrafterError> {
let shape = ffi::array_shape(logits);
if shape.len() != 3
|| shape[0] != batch_size as i32
|| shape[1] != block_size as i32
{
if shape.len() != 3 || shape[0] != batch_size as i32 || shape[1] != block_size as i32 {
return Err(DrafterError::DraftFailed {
reason: format!(
"DFlash drafter (batched) expected logits shape \
Expand Down Expand Up @@ -468,11 +464,7 @@ fn sample_block_per_position_batched(
for i in 0..n {
// Row `(b, i+1)` of the [B, L, V] logits.
let pos = (i + 1) as i32;
let row = ffi::slice(
logits,
&[b, pos, 0_i32],
&[b + 1, pos + 1, vocab],
);
let row = ffi::slice(logits, &[b, pos, 0_i32], &[b + 1, pos + 1, vocab]);
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
// / argmax expect `[batch, vocab]`).
let row = ffi::reshape(&row, &[1_i32, vocab]);
Expand Down Expand Up @@ -536,7 +528,11 @@ fn sample_block_per_position(
for i in 0..n {
// Row `i + 1` of the [1, L, V] logits.
let row_idx = (i + 1) as i32;
let row = ffi::slice(logits, &[0_i32, row_idx, 0_i32], &[1_i32, row_idx + 1, vocab]);
let row = ffi::slice(
logits,
&[0_i32, row_idx, 0_i32],
&[1_i32, row_idx + 1, vocab],
);
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
// / argmax expect `[batch, vocab]`).
let row = ffi::reshape(&row, &[1_i32, vocab]);
Expand Down Expand Up @@ -588,10 +584,7 @@ fn sample_block_per_position_array(
}

let tokens = sample_block_per_position(logits, block_size, sampler)?;
Ok(ffi::from_slice_i32(
&tokens,
&[1, (block_size - 1) as i32],
))
Ok(ffi::from_slice_i32(&tokens, &[1, (block_size - 1) as i32]))
}

#[cfg(test)]
Expand Down Expand Up @@ -620,10 +613,7 @@ mod tests {
ffi::zeros(&[4, 4], dtype::BFLOAT16),
);
// A non-bf16 tensor: should pass through.
weights.insert(
"fc.weight".to_string(),
ffi::zeros(&[4, 4], dtype::FLOAT16),
);
weights.insert("fc.weight".to_string(), ffi::zeros(&[4, 4], dtype::FLOAT16));

convert_bf16_to_f16_non_quantized(&mut weights);

Expand Down
9 changes: 2 additions & 7 deletions src/lib/mlxcel-core/src/drafter/dflash/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,13 @@ impl DFlashDecoderLayer {
let post_attention_layernorm_w = weights
.get(&format!("{prefix}.post_attention_layernorm.weight"))
.map(|w| ffi::copy(w))
.ok_or_else(|| {
format!("Weight not found: {prefix}.post_attention_layernorm.weight")
})?;
.ok_or_else(|| format!("Weight not found: {prefix}.post_attention_layernorm.weight"))?;

Ok(Self {
self_attn,
mlp,
input_layernorm: RMSNorm::new(input_layernorm_w, config.rms_norm_eps),
post_attention_layernorm: RMSNorm::new(
post_attention_layernorm_w,
config.rms_norm_eps,
),
post_attention_layernorm: RMSNorm::new(post_attention_layernorm_w, config.rms_norm_eps),
})
}
}
24 changes: 6 additions & 18 deletions src/lib/mlxcel-core/src/drafter/dflash/mlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,12 @@ impl DFlashMlp {
group_size: i32,
bits: i32,
) -> Result<Self, String> {
let gate_proj = UnifiedLinear::from_weights(
weights,
&format!("{prefix}.gate_proj"),
group_size,
bits,
)?;
let up_proj = UnifiedLinear::from_weights(
weights,
&format!("{prefix}.up_proj"),
group_size,
bits,
)?;
let down_proj = UnifiedLinear::from_weights(
weights,
&format!("{prefix}.down_proj"),
group_size,
bits,
)?;
let gate_proj =
UnifiedLinear::from_weights(weights, &format!("{prefix}.gate_proj"), group_size, bits)?;
let up_proj =
UnifiedLinear::from_weights(weights, &format!("{prefix}.up_proj"), group_size, bits)?;
let down_proj =
UnifiedLinear::from_weights(weights, &format!("{prefix}.down_proj"), group_size, bits)?;
Ok(Self {
gate_proj,
up_proj,
Expand Down
3 changes: 1 addition & 2 deletions src/lib/mlxcel-core/src/drafter/dflash/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ pub(crate) fn materialize_argmax_i32_vec(argmax: &MlxArray, expected_len: usize)
.take(expected_len)
.map(|chunk| {
i64::from_ne_bytes([
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
chunk[7],
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
]) as i32
})
.collect(),
Expand Down
16 changes: 11 additions & 5 deletions src/lib/mlxcel-core/src/drafter/dflash/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,14 @@ mod tests {
fn tiny_weights_without_embed() -> WeightMap {
let mut w: WeightMap = std::collections::HashMap::new();
// fc: [hidden, num_target_layers * hidden] = [4, 16]
w.insert("fc.weight".to_string(), ffi::zeros(&[4, 16], dtype::FLOAT32));
w.insert("hidden_norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32));
w.insert(
"fc.weight".to_string(),
ffi::zeros(&[4, 16], dtype::FLOAT32),
);
w.insert(
"hidden_norm.weight".to_string(),
ffi::zeros(&[4], dtype::FLOAT32),
);
w.insert("norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32));
// Layer 0 projections. q out = n_heads*head_dim = 4; k/v out =
// n_kv_heads*head_dim = 2; o in = 4.
Expand Down Expand Up @@ -593,9 +599,9 @@ mod tests {

// Stand-in for the target's embedding table — a regular
// [vocab, hidden] tensor, the same shape Qwen 3.5 hands out.
let target_embed = crate::layers::UnifiedEmbedding::Regular(
crate::layers::Embedding::new(ffi::zeros(&[8, 4], dtype::FLOAT32)),
);
let target_embed = crate::layers::UnifiedEmbedding::Regular(crate::layers::Embedding::new(
ffi::zeros(&[8, 4], dtype::FLOAT32),
));
model.bind_target_embedding(target_embed);

assert!(
Expand Down
Loading