Skip to content

Commit a1b9eaa

Browse files
authored
chore: reformat tree with cargo fmt and add fmt to CI gate (#13)
Closes #4. Commit 1: `cargo fmt --all` reformat across the workspace (49 files, whitespace / line-break reflows / trivial argument alignment, no semantic changes). Commit 2: New `fmt` job in `.github/workflows/ci.yml` mirrored after the existing `deny` job (needs: changes, ubuntu-latest, permissions: contents: read, persist-credentials: false, rustfmt component, then `cargo fmt --all -- --check`). CONTRIBUTING.md and .github/PULL_REQUEST_TEMPLATE.md updated so the local quality-gate list and PR checklist explicitly note the enforced fmt command. The new fmt job verified itself against this PR before merge (the push triggered ci.yml with the freshly-added gate; SUCCESS). Future PRs will trip the gate immediately and no further fmt drift can accumulate.
1 parent cb2002b commit a1b9eaa

52 files changed

Lines changed: 563 additions & 798 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ What did you run to convince yourself this works? Be specific.
3333
For inference changes, real-checkpoint validation is required — synthetic-only is not enough.
3434
-->
3535

36-
- [ ] `cargo fmt --check`
36+
- [ ] `cargo fmt --all -- --check` (enforced by CI — violations block merge)
3737
- [ ] `cargo clippy --all-targets -- -D warnings`
3838
- [ ] `cargo test --release`
3939
- [ ] `cargo deny check`

.github/workflows/ci.yml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# PR / push CI for security audit. Builds and tests are covered by
1+
# PR / push CI for lightweight quality gates. Builds and tests are covered by
22
# release.yml (signed artifacts) and pipeline-parallel-ci.yml (distributed
3-
# runs); this workflow runs the lightweight cargo-deny gate on every
4-
# touched-Rust change so license drift and new advisories are caught at
5-
# PR time.
3+
# runs); this workflow runs cargo-deny (license / advisory) and cargo-fmt
4+
# checks on every touched-Rust change so formatting drift and license issues
5+
# are caught at PR time.
66

77
name: CI
88

@@ -54,3 +54,19 @@ jobs:
5454
with:
5555
command: check
5656
log-level: warn
57+
58+
fmt:
59+
name: cargo-fmt
60+
needs: changes
61+
if: needs.changes.outputs.rust == 'true'
62+
runs-on: ubuntu-latest
63+
permissions:
64+
contents: read
65+
steps:
66+
- uses: actions/checkout@v6
67+
with:
68+
persist-credentials: false
69+
- uses: dtolnay/rust-toolchain@stable
70+
with:
71+
components: rustfmt
72+
- run: cargo fmt --all -- --check

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ Thank you for your interest in contributing to mlxcel! This document covers the
4141
```
4242
4. Run the local quality gates:
4343
```bash
44-
cargo fmt --check
44+
cargo fmt --all -- --check # enforced by CI; fmt violations block merge
4545
cargo clippy --all-targets -- -D warnings
46-
cargo deny check # advisories + licenses + sources
46+
cargo deny check # advisories + licenses + sources
4747
```
4848
5. For inference changes, validate against a real checkpoint — synthetic or build-only validation is not enough (see [`AGENTS.md`](AGENTS.md) for why).
4949
6. Commit with a conventional prefix (see below) and a clear message.

src/lib/mlxcel-core/src/cache.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,10 +2373,8 @@ impl KVCache {
23732373

23742374
let k_slice = ffi::slice(k_int8, &[0, 0, 0, 0], &[ks[0], ks[1], live_len, ks[3]]);
23752375
let v_slice = ffi::slice(v_int8, &[0, 0, 0, 0], &[vs[0], vs[1], live_len, vs[3]]);
2376-
let ks_slice =
2377-
ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
2378-
let vs_slice =
2379-
ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);
2376+
let ks_slice = ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
2377+
let vs_slice = ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);
23802378

23812379
(
23822380
dequantize(&k_slice, &ks_slice),
@@ -3785,8 +3783,7 @@ impl RotatingKVCache {
37853783
return (new_keys, new_values);
37863784
}
37873785

3788-
let (base_k, base_v, current_seq_len) =
3789-
self.visible_fp16_prefix_for_concat();
3786+
let (base_k, base_v, current_seq_len) = self.visible_fp16_prefix_for_concat();
37903787

37913788
let concat_k = concatenate(&base_k, &new_keys, 2);
37923789
let concat_v = concatenate(&base_v, &new_values, 2);
@@ -5681,10 +5678,7 @@ mod tests {
56815678
.collect::<Vec<_>>()
56825679
};
56835680
assert_eq!(to_f32(&visible_keys), vec![1.0, 5.0, 6.0, 7.0, 8.0]);
5684-
assert_eq!(
5685-
to_f32(&visible_values),
5686-
vec![10.0, 50.0, 60.0, 70.0, 80.0]
5687-
);
5681+
assert_eq!(to_f32(&visible_values), vec![10.0, 50.0, 60.0, 70.0, 80.0]);
56885682
}
56895683

56905684
#[test]
@@ -6190,10 +6184,8 @@ mod tests {
61906184
}
61916185
let (q_unrot, _) = unit_token(42);
61926186
let q_ref = rotate_at(&q_unrot, M);
6193-
let (k_ref, v_ref) = cache_ref.update_and_fetch(
6194-
rotate_at(&unit_token(99).0, M),
6195-
unit_token(99).1,
6196-
);
6187+
let (k_ref, v_ref) =
6188+
cache_ref.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
61976189
let out_ref = mlxcel_core::causal_attention(&q_ref, &k_ref, &v_ref, scale, 0.0, 0);
61986190
let out_ref_f32 = to_f32(&out_ref);
61996191

@@ -6216,11 +6208,10 @@ mod tests {
62166208
// position `M` to simulate the pre-fix offset decrement.
62176209
assert_eq!(cache_broken.trim_front(N), N);
62186210
let q_broken = rotate_at(&q_unrot, M);
6219-
let (k_broken, v_broken) = cache_broken.update_and_fetch(
6220-
rotate_at(&unit_token(99).0, M),
6221-
unit_token(99).1,
6222-
);
6223-
let out_broken = mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
6211+
let (k_broken, v_broken) =
6212+
cache_broken.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
6213+
let out_broken =
6214+
mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
62246215
let out_broken_f32 = to_f32(&out_broken);
62256216

62266217
let mut sq_err = 0.0_f64;

src/lib/mlxcel-core/src/cache/detach.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,19 +241,14 @@ impl DetachedKVCache {
241241

242242
// Slice axis 2 (seq-len) of each tensor to the new length. Width axes
243243
// (B, H, head_dim / packed_dim) come from the existing shape.
244-
let trim_axis_seq = |a: &Option<UniquePtr<MlxArray>>,
245-
tail_axis: i32|
246-
-> Option<UniquePtr<MlxArray>> {
247-
a.as_ref().map(|arr| {
248-
let shape = ffi::array_shape(arr);
249-
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
250-
ffi::slice(
251-
arr,
252-
&[0, 0, 0, 0],
253-
&[shape[0], shape[1], new_len, last],
254-
)
255-
})
256-
};
244+
let trim_axis_seq =
245+
|a: &Option<UniquePtr<MlxArray>>, tail_axis: i32| -> Option<UniquePtr<MlxArray>> {
246+
a.as_ref().map(|arr| {
247+
let shape = ffi::array_shape(arr);
248+
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
249+
ffi::slice(arr, &[0, 0, 0, 0], &[shape[0], shape[1], new_len, last])
250+
})
251+
};
257252

258253
if self.mode == KVCacheMode::Turbo4Delegated {
259254
// K is unified — same shape contract as Fp16. Slice to `new_len`.

src/lib/mlxcel-core/src/cache/detach_tests.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,7 @@ fn detached_kvcache_trim_to_mid_buffer_resets_offset_and_visible_keys() {
677677
#[test]
678678
fn detached_kvcache_trim_to_zero_drops_storage_and_keeps_mode() {
679679
let mut live = KVCache::new_with_mode(KVCacheMode::Int8);
680-
live.update(
681-
fp32_tokens(&[1.0, 2.0, 3.0]),
682-
fp32_tokens(&[4.0, 5.0, 6.0]),
683-
);
680+
live.update(fp32_tokens(&[1.0, 2.0, 3.0]), fp32_tokens(&[4.0, 5.0, 6.0]));
684681
let mut handle = live.clone_handle();
685682
assert_eq!(handle.seq_len(), 3);
686683
assert_eq!(handle.mode(), KVCacheMode::Int8);
@@ -698,7 +695,9 @@ fn detached_kvcache_trim_to_exact_offset_is_noop() {
698695
live.update(fp32_tokens(&[1.0, 2.0]), fp32_tokens(&[3.0, 4.0]));
699696
let mut handle = live.clone_handle();
700697

701-
handle.trim_to(2).expect("trim_to exact offset must succeed");
698+
handle
699+
.trim_to(2)
700+
.expect("trim_to exact offset must succeed");
702701
assert_eq!(handle.seq_len(), 2);
703702

704703
let mut restored = KVCache::new();
@@ -825,7 +824,9 @@ fn detached_cache_set_truncate_to_zero_drops_every_layer() {
825824
let v: Vec<f32> = (1..=4).map(|i| i as f32 * 10.0).collect();
826825
let mut detached = detached_set_with_fp16_layers(2, &k, &v);
827826

828-
detached.truncate_to(0).expect("truncate_to zero must succeed");
827+
detached
828+
.truncate_to(0)
829+
.expect("truncate_to zero must succeed");
829830
assert!(detached.caches.iter().all(|c| c.is_empty()));
830831
assert_eq!(detached.current_offset, 0);
831832
assert_eq!(detached.prompt_len, 0);
@@ -940,10 +941,7 @@ fn detached_cache_set_truncate_to_int8_preserves_dequantization() {
940941
let seq_adopted = pool.adopt(&model, detached).unwrap();
941942
let (k_a, v_a) = {
942943
let caches = pool.get_caches_mut(seq_adopted).unwrap();
943-
caches[0].update_and_fetch(
944-
fp32_tokens(&full_k[5..]),
945-
fp32_tokens(&full_v[5..]),
946-
)
944+
caches[0].update_and_fetch(fp32_tokens(&full_k[5..]), fp32_tokens(&full_v[5..]))
947945
};
948946
eval(&k_a);
949947
eval(&v_a);

src/lib/mlxcel-core/src/drafter/dflash/attention.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,4 +441,3 @@ mod tests {
441441
);
442442
}
443443
}
444-

src/lib/mlxcel-core/src/drafter/dflash/drafter.rs

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,11 @@ impl DFlashDrafter {
106106
path: config_path.display().to_string(),
107107
source: e,
108108
})?;
109-
let config = DFlashConfig::from_json(&config_json).map_err(|e| {
110-
DrafterError::ConfigParse {
109+
let config =
110+
DFlashConfig::from_json(&config_json).map_err(|e| DrafterError::ConfigParse {
111111
path: config_path.display().to_string(),
112112
source: serde::de::Error::custom(e),
113-
}
114-
})?;
113+
})?;
115114

116115
let mut weights = crate::weights::load_weights_from_dir(path)
117116
.map_err(|msg| DrafterError::LoadFailed { reason: msg })?;
@@ -209,8 +208,9 @@ impl Drafter for DFlashDrafter {
209208
// expose `embed_tokens` (e.g. a non-Qwen-3.5 model fed in by
210209
// mistake).
211210
if self.model.needs_embed_binding() {
212-
let embed = target.embed_tokens_module().ok_or_else(|| {
213-
DrafterError::BindFailed {
211+
let embed = target
212+
.embed_tokens_module()
213+
.ok_or_else(|| DrafterError::BindFailed {
214214
reason: format!(
215215
"DFlash drafter checkpoint omits embed_tokens.weight \
216216
and the target does not expose embed_tokens_module(); \
@@ -220,8 +220,7 @@ impl Drafter for DFlashDrafter {
220220
(kind = {})",
221221
self.kind()
222222
),
223-
}
224-
})?;
223+
})?;
225224
self.model.bind_target_embedding(embed);
226225
} else {
227226
// Legacy capability smoke-test for self-contained checkpoints:
@@ -427,10 +426,7 @@ fn sample_block_per_position_batched(
427426
sampler: &SamplingConfig,
428427
) -> Result<Vec<Vec<i32>>, DrafterError> {
429428
let shape = ffi::array_shape(logits);
430-
if shape.len() != 3
431-
|| shape[0] != batch_size as i32
432-
|| shape[1] != block_size as i32
433-
{
429+
if shape.len() != 3 || shape[0] != batch_size as i32 || shape[1] != block_size as i32 {
434430
return Err(DrafterError::DraftFailed {
435431
reason: format!(
436432
"DFlash drafter (batched) expected logits shape \
@@ -468,11 +464,7 @@ fn sample_block_per_position_batched(
468464
for i in 0..n {
469465
// Row `(b, i+1)` of the [B, L, V] logits.
470466
let pos = (i + 1) as i32;
471-
let row = ffi::slice(
472-
logits,
473-
&[b, pos, 0_i32],
474-
&[b + 1, pos + 1, vocab],
475-
);
467+
let row = ffi::slice(logits, &[b, pos, 0_i32], &[b + 1, pos + 1, vocab]);
476468
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
477469
// / argmax expect `[batch, vocab]`).
478470
let row = ffi::reshape(&row, &[1_i32, vocab]);
@@ -536,7 +528,11 @@ fn sample_block_per_position(
536528
for i in 0..n {
537529
// Row `i + 1` of the [1, L, V] logits.
538530
let row_idx = (i + 1) as i32;
539-
let row = ffi::slice(logits, &[0_i32, row_idx, 0_i32], &[1_i32, row_idx + 1, vocab]);
531+
let row = ffi::slice(
532+
logits,
533+
&[0_i32, row_idx, 0_i32],
534+
&[1_i32, row_idx + 1, vocab],
535+
);
540536
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
541537
// / argmax expect `[batch, vocab]`).
542538
let row = ffi::reshape(&row, &[1_i32, vocab]);
@@ -588,10 +584,7 @@ fn sample_block_per_position_array(
588584
}
589585

590586
let tokens = sample_block_per_position(logits, block_size, sampler)?;
591-
Ok(ffi::from_slice_i32(
592-
&tokens,
593-
&[1, (block_size - 1) as i32],
594-
))
587+
Ok(ffi::from_slice_i32(&tokens, &[1, (block_size - 1) as i32]))
595588
}
596589

597590
#[cfg(test)]
@@ -620,10 +613,7 @@ mod tests {
620613
ffi::zeros(&[4, 4], dtype::BFLOAT16),
621614
);
622615
// A non-bf16 tensor: should pass through.
623-
weights.insert(
624-
"fc.weight".to_string(),
625-
ffi::zeros(&[4, 4], dtype::FLOAT16),
626-
);
616+
weights.insert("fc.weight".to_string(), ffi::zeros(&[4, 4], dtype::FLOAT16));
627617

628618
convert_bf16_to_f16_non_quantized(&mut weights);
629619

src/lib/mlxcel-core/src/drafter/dflash/layer.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,13 @@ impl DFlashDecoderLayer {
8383
let post_attention_layernorm_w = weights
8484
.get(&format!("{prefix}.post_attention_layernorm.weight"))
8585
.map(|w| ffi::copy(w))
86-
.ok_or_else(|| {
87-
format!("Weight not found: {prefix}.post_attention_layernorm.weight")
88-
})?;
86+
.ok_or_else(|| format!("Weight not found: {prefix}.post_attention_layernorm.weight"))?;
8987

9088
Ok(Self {
9189
self_attn,
9290
mlp,
9391
input_layernorm: RMSNorm::new(input_layernorm_w, config.rms_norm_eps),
94-
post_attention_layernorm: RMSNorm::new(
95-
post_attention_layernorm_w,
96-
config.rms_norm_eps,
97-
),
92+
post_attention_layernorm: RMSNorm::new(post_attention_layernorm_w, config.rms_norm_eps),
9893
})
9994
}
10095
}

src/lib/mlxcel-core/src/drafter/dflash/mlp.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,12 @@ impl DFlashMlp {
6767
group_size: i32,
6868
bits: i32,
6969
) -> Result<Self, String> {
70-
let gate_proj = UnifiedLinear::from_weights(
71-
weights,
72-
&format!("{prefix}.gate_proj"),
73-
group_size,
74-
bits,
75-
)?;
76-
let up_proj = UnifiedLinear::from_weights(
77-
weights,
78-
&format!("{prefix}.up_proj"),
79-
group_size,
80-
bits,
81-
)?;
82-
let down_proj = UnifiedLinear::from_weights(
83-
weights,
84-
&format!("{prefix}.down_proj"),
85-
group_size,
86-
bits,
87-
)?;
70+
let gate_proj =
71+
UnifiedLinear::from_weights(weights, &format!("{prefix}.gate_proj"), group_size, bits)?;
72+
let up_proj =
73+
UnifiedLinear::from_weights(weights, &format!("{prefix}.up_proj"), group_size, bits)?;
74+
let down_proj =
75+
UnifiedLinear::from_weights(weights, &format!("{prefix}.down_proj"), group_size, bits)?;
8876
Ok(Self {
8977
gate_proj,
9078
up_proj,

0 commit comments

Comments
 (0)