Skip to content

Commit 9af1446

Browse files
committed
chore: reformat tree with cargo fmt
Apply default rustfmt formatting to the entire workspace. No behavior changes — purely mechanical whitespace and line-break adjustments. Reviewers can use `git diff -w` to confirm the diff is bounded to formatting only.
1 parent cb2002b commit 9af1446

49 files changed

Lines changed: 540 additions & 791 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/lib/mlxcel-core/src/cache.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,10 +2373,8 @@ impl KVCache {
23732373

23742374
let k_slice = ffi::slice(k_int8, &[0, 0, 0, 0], &[ks[0], ks[1], live_len, ks[3]]);
23752375
let v_slice = ffi::slice(v_int8, &[0, 0, 0, 0], &[vs[0], vs[1], live_len, vs[3]]);
2376-
let ks_slice =
2377-
ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
2378-
let vs_slice =
2379-
ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);
2376+
let ks_slice = ffi::slice(k_scales, &[0, 0, 0, 0], &[kss[0], kss[1], live_len, 1]);
2377+
let vs_slice = ffi::slice(v_scales, &[0, 0, 0, 0], &[vss[0], vss[1], live_len, 1]);
23802378

23812379
(
23822380
dequantize(&k_slice, &ks_slice),
@@ -3785,8 +3783,7 @@ impl RotatingKVCache {
37853783
return (new_keys, new_values);
37863784
}
37873785

3788-
let (base_k, base_v, current_seq_len) =
3789-
self.visible_fp16_prefix_for_concat();
3786+
let (base_k, base_v, current_seq_len) = self.visible_fp16_prefix_for_concat();
37903787

37913788
let concat_k = concatenate(&base_k, &new_keys, 2);
37923789
let concat_v = concatenate(&base_v, &new_values, 2);
@@ -5681,10 +5678,7 @@ mod tests {
56815678
.collect::<Vec<_>>()
56825679
};
56835680
assert_eq!(to_f32(&visible_keys), vec![1.0, 5.0, 6.0, 7.0, 8.0]);
5684-
assert_eq!(
5685-
to_f32(&visible_values),
5686-
vec![10.0, 50.0, 60.0, 70.0, 80.0]
5687-
);
5681+
assert_eq!(to_f32(&visible_values), vec![10.0, 50.0, 60.0, 70.0, 80.0]);
56885682
}
56895683

56905684
#[test]
@@ -6190,10 +6184,8 @@ mod tests {
61906184
}
61916185
let (q_unrot, _) = unit_token(42);
61926186
let q_ref = rotate_at(&q_unrot, M);
6193-
let (k_ref, v_ref) = cache_ref.update_and_fetch(
6194-
rotate_at(&unit_token(99).0, M),
6195-
unit_token(99).1,
6196-
);
6187+
let (k_ref, v_ref) =
6188+
cache_ref.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
61976189
let out_ref = mlxcel_core::causal_attention(&q_ref, &k_ref, &v_ref, scale, 0.0, 0);
61986190
let out_ref_f32 = to_f32(&out_ref);
61996191

@@ -6216,11 +6208,10 @@ mod tests {
62166208
// position `M` to simulate the pre-fix offset decrement.
62176209
assert_eq!(cache_broken.trim_front(N), N);
62186210
let q_broken = rotate_at(&q_unrot, M);
6219-
let (k_broken, v_broken) = cache_broken.update_and_fetch(
6220-
rotate_at(&unit_token(99).0, M),
6221-
unit_token(99).1,
6222-
);
6223-
let out_broken = mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
6211+
let (k_broken, v_broken) =
6212+
cache_broken.update_and_fetch(rotate_at(&unit_token(99).0, M), unit_token(99).1);
6213+
let out_broken =
6214+
mlxcel_core::causal_attention(&q_broken, &k_broken, &v_broken, scale, 0.0, 0);
62246215
let out_broken_f32 = to_f32(&out_broken);
62256216

62266217
let mut sq_err = 0.0_f64;

src/lib/mlxcel-core/src/cache/detach.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,19 +241,14 @@ impl DetachedKVCache {
241241

242242
// Slice axis 2 (seq-len) of each tensor to the new length. Width axes
243243
// (B, H, head_dim / packed_dim) come from the existing shape.
244-
let trim_axis_seq = |a: &Option<UniquePtr<MlxArray>>,
245-
tail_axis: i32|
246-
-> Option<UniquePtr<MlxArray>> {
247-
a.as_ref().map(|arr| {
248-
let shape = ffi::array_shape(arr);
249-
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
250-
ffi::slice(
251-
arr,
252-
&[0, 0, 0, 0],
253-
&[shape[0], shape[1], new_len, last],
254-
)
255-
})
256-
};
244+
let trim_axis_seq =
245+
|a: &Option<UniquePtr<MlxArray>>, tail_axis: i32| -> Option<UniquePtr<MlxArray>> {
246+
a.as_ref().map(|arr| {
247+
let shape = ffi::array_shape(arr);
248+
let last = if tail_axis == 0 { shape[3] } else { tail_axis };
249+
ffi::slice(arr, &[0, 0, 0, 0], &[shape[0], shape[1], new_len, last])
250+
})
251+
};
257252

258253
if self.mode == KVCacheMode::Turbo4Delegated {
259254
// K is unified — same shape contract as Fp16. Slice to `new_len`.

src/lib/mlxcel-core/src/cache/detach_tests.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,7 @@ fn detached_kvcache_trim_to_mid_buffer_resets_offset_and_visible_keys() {
677677
#[test]
678678
fn detached_kvcache_trim_to_zero_drops_storage_and_keeps_mode() {
679679
let mut live = KVCache::new_with_mode(KVCacheMode::Int8);
680-
live.update(
681-
fp32_tokens(&[1.0, 2.0, 3.0]),
682-
fp32_tokens(&[4.0, 5.0, 6.0]),
683-
);
680+
live.update(fp32_tokens(&[1.0, 2.0, 3.0]), fp32_tokens(&[4.0, 5.0, 6.0]));
684681
let mut handle = live.clone_handle();
685682
assert_eq!(handle.seq_len(), 3);
686683
assert_eq!(handle.mode(), KVCacheMode::Int8);
@@ -698,7 +695,9 @@ fn detached_kvcache_trim_to_exact_offset_is_noop() {
698695
live.update(fp32_tokens(&[1.0, 2.0]), fp32_tokens(&[3.0, 4.0]));
699696
let mut handle = live.clone_handle();
700697

701-
handle.trim_to(2).expect("trim_to exact offset must succeed");
698+
handle
699+
.trim_to(2)
700+
.expect("trim_to exact offset must succeed");
702701
assert_eq!(handle.seq_len(), 2);
703702

704703
let mut restored = KVCache::new();
@@ -825,7 +824,9 @@ fn detached_cache_set_truncate_to_zero_drops_every_layer() {
825824
let v: Vec<f32> = (1..=4).map(|i| i as f32 * 10.0).collect();
826825
let mut detached = detached_set_with_fp16_layers(2, &k, &v);
827826

828-
detached.truncate_to(0).expect("truncate_to zero must succeed");
827+
detached
828+
.truncate_to(0)
829+
.expect("truncate_to zero must succeed");
829830
assert!(detached.caches.iter().all(|c| c.is_empty()));
830831
assert_eq!(detached.current_offset, 0);
831832
assert_eq!(detached.prompt_len, 0);
@@ -940,10 +941,7 @@ fn detached_cache_set_truncate_to_int8_preserves_dequantization() {
940941
let seq_adopted = pool.adopt(&model, detached).unwrap();
941942
let (k_a, v_a) = {
942943
let caches = pool.get_caches_mut(seq_adopted).unwrap();
943-
caches[0].update_and_fetch(
944-
fp32_tokens(&full_k[5..]),
945-
fp32_tokens(&full_v[5..]),
946-
)
944+
caches[0].update_and_fetch(fp32_tokens(&full_k[5..]), fp32_tokens(&full_v[5..]))
947945
};
948946
eval(&k_a);
949947
eval(&v_a);

src/lib/mlxcel-core/src/drafter/dflash/attention.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,4 +441,3 @@ mod tests {
441441
);
442442
}
443443
}
444-

src/lib/mlxcel-core/src/drafter/dflash/drafter.rs

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,11 @@ impl DFlashDrafter {
106106
path: config_path.display().to_string(),
107107
source: e,
108108
})?;
109-
let config = DFlashConfig::from_json(&config_json).map_err(|e| {
110-
DrafterError::ConfigParse {
109+
let config =
110+
DFlashConfig::from_json(&config_json).map_err(|e| DrafterError::ConfigParse {
111111
path: config_path.display().to_string(),
112112
source: serde::de::Error::custom(e),
113-
}
114-
})?;
113+
})?;
115114

116115
let mut weights = crate::weights::load_weights_from_dir(path)
117116
.map_err(|msg| DrafterError::LoadFailed { reason: msg })?;
@@ -209,8 +208,9 @@ impl Drafter for DFlashDrafter {
209208
// expose `embed_tokens` (e.g. a non-Qwen-3.5 model fed in by
210209
// mistake).
211210
if self.model.needs_embed_binding() {
212-
let embed = target.embed_tokens_module().ok_or_else(|| {
213-
DrafterError::BindFailed {
211+
let embed = target
212+
.embed_tokens_module()
213+
.ok_or_else(|| DrafterError::BindFailed {
214214
reason: format!(
215215
"DFlash drafter checkpoint omits embed_tokens.weight \
216216
and the target does not expose embed_tokens_module(); \
@@ -220,8 +220,7 @@ impl Drafter for DFlashDrafter {
220220
(kind = {})",
221221
self.kind()
222222
),
223-
}
224-
})?;
223+
})?;
225224
self.model.bind_target_embedding(embed);
226225
} else {
227226
// Legacy capability smoke-test for self-contained checkpoints:
@@ -427,10 +426,7 @@ fn sample_block_per_position_batched(
427426
sampler: &SamplingConfig,
428427
) -> Result<Vec<Vec<i32>>, DrafterError> {
429428
let shape = ffi::array_shape(logits);
430-
if shape.len() != 3
431-
|| shape[0] != batch_size as i32
432-
|| shape[1] != block_size as i32
433-
{
429+
if shape.len() != 3 || shape[0] != batch_size as i32 || shape[1] != block_size as i32 {
434430
return Err(DrafterError::DraftFailed {
435431
reason: format!(
436432
"DFlash drafter (batched) expected logits shape \
@@ -468,11 +464,7 @@ fn sample_block_per_position_batched(
468464
for i in 0..n {
469465
// Row `(b, i+1)` of the [B, L, V] logits.
470466
let pos = (i + 1) as i32;
471-
let row = ffi::slice(
472-
logits,
473-
&[b, pos, 0_i32],
474-
&[b + 1, pos + 1, vocab],
475-
);
467+
let row = ffi::slice(logits, &[b, pos, 0_i32], &[b + 1, pos + 1, vocab]);
476468
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
477469
// / argmax expect `[batch, vocab]`).
478470
let row = ffi::reshape(&row, &[1_i32, vocab]);
@@ -536,7 +528,11 @@ fn sample_block_per_position(
536528
for i in 0..n {
537529
// Row `i + 1` of the [1, L, V] logits.
538530
let row_idx = (i + 1) as i32;
539-
let row = ffi::slice(logits, &[0_i32, row_idx, 0_i32], &[1_i32, row_idx + 1, vocab]);
531+
let row = ffi::slice(
532+
logits,
533+
&[0_i32, row_idx, 0_i32],
534+
&[1_i32, row_idx + 1, vocab],
535+
);
540536
// Drop the seq axis so we get a `[1, vocab]` 2D slice (fused_sample
541537
// / argmax expect `[batch, vocab]`).
542538
let row = ffi::reshape(&row, &[1_i32, vocab]);
@@ -588,10 +584,7 @@ fn sample_block_per_position_array(
588584
}
589585

590586
let tokens = sample_block_per_position(logits, block_size, sampler)?;
591-
Ok(ffi::from_slice_i32(
592-
&tokens,
593-
&[1, (block_size - 1) as i32],
594-
))
587+
Ok(ffi::from_slice_i32(&tokens, &[1, (block_size - 1) as i32]))
595588
}
596589

597590
#[cfg(test)]
@@ -620,10 +613,7 @@ mod tests {
620613
ffi::zeros(&[4, 4], dtype::BFLOAT16),
621614
);
622615
// A non-bf16 tensor: should pass through.
623-
weights.insert(
624-
"fc.weight".to_string(),
625-
ffi::zeros(&[4, 4], dtype::FLOAT16),
626-
);
616+
weights.insert("fc.weight".to_string(), ffi::zeros(&[4, 4], dtype::FLOAT16));
627617

628618
convert_bf16_to_f16_non_quantized(&mut weights);
629619

src/lib/mlxcel-core/src/drafter/dflash/layer.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,13 @@ impl DFlashDecoderLayer {
8383
let post_attention_layernorm_w = weights
8484
.get(&format!("{prefix}.post_attention_layernorm.weight"))
8585
.map(|w| ffi::copy(w))
86-
.ok_or_else(|| {
87-
format!("Weight not found: {prefix}.post_attention_layernorm.weight")
88-
})?;
86+
.ok_or_else(|| format!("Weight not found: {prefix}.post_attention_layernorm.weight"))?;
8987

9088
Ok(Self {
9189
self_attn,
9290
mlp,
9391
input_layernorm: RMSNorm::new(input_layernorm_w, config.rms_norm_eps),
94-
post_attention_layernorm: RMSNorm::new(
95-
post_attention_layernorm_w,
96-
config.rms_norm_eps,
97-
),
92+
post_attention_layernorm: RMSNorm::new(post_attention_layernorm_w, config.rms_norm_eps),
9893
})
9994
}
10095
}

src/lib/mlxcel-core/src/drafter/dflash/mlp.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,12 @@ impl DFlashMlp {
6767
group_size: i32,
6868
bits: i32,
6969
) -> Result<Self, String> {
70-
let gate_proj = UnifiedLinear::from_weights(
71-
weights,
72-
&format!("{prefix}.gate_proj"),
73-
group_size,
74-
bits,
75-
)?;
76-
let up_proj = UnifiedLinear::from_weights(
77-
weights,
78-
&format!("{prefix}.up_proj"),
79-
group_size,
80-
bits,
81-
)?;
82-
let down_proj = UnifiedLinear::from_weights(
83-
weights,
84-
&format!("{prefix}.down_proj"),
85-
group_size,
86-
bits,
87-
)?;
70+
let gate_proj =
71+
UnifiedLinear::from_weights(weights, &format!("{prefix}.gate_proj"), group_size, bits)?;
72+
let up_proj =
73+
UnifiedLinear::from_weights(weights, &format!("{prefix}.up_proj"), group_size, bits)?;
74+
let down_proj =
75+
UnifiedLinear::from_weights(weights, &format!("{prefix}.down_proj"), group_size, bits)?;
8876
Ok(Self {
8977
gate_proj,
9078
up_proj,

src/lib/mlxcel-core/src/drafter/dflash/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ pub(crate) fn materialize_argmax_i32_vec(argmax: &MlxArray, expected_len: usize)
104104
.take(expected_len)
105105
.map(|chunk| {
106106
i64::from_ne_bytes([
107-
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
108-
chunk[7],
107+
chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
109108
]) as i32
110109
})
111110
.collect(),

src/lib/mlxcel-core/src/drafter/dflash/model.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,14 @@ mod tests {
464464
fn tiny_weights_without_embed() -> WeightMap {
465465
let mut w: WeightMap = std::collections::HashMap::new();
466466
// fc: [hidden, num_target_layers * hidden] = [4, 16]
467-
w.insert("fc.weight".to_string(), ffi::zeros(&[4, 16], dtype::FLOAT32));
468-
w.insert("hidden_norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32));
467+
w.insert(
468+
"fc.weight".to_string(),
469+
ffi::zeros(&[4, 16], dtype::FLOAT32),
470+
);
471+
w.insert(
472+
"hidden_norm.weight".to_string(),
473+
ffi::zeros(&[4], dtype::FLOAT32),
474+
);
469475
w.insert("norm.weight".to_string(), ffi::zeros(&[4], dtype::FLOAT32));
470476
// Layer 0 projections. q out = n_heads*head_dim = 4; k/v out =
471477
// n_kv_heads*head_dim = 2; o in = 4.
@@ -593,9 +599,9 @@ mod tests {
593599

594600
// Stand-in for the target's embedding table — a regular
595601
// [vocab, hidden] tensor, the same shape Qwen 3.5 hands out.
596-
let target_embed = crate::layers::UnifiedEmbedding::Regular(
597-
crate::layers::Embedding::new(ffi::zeros(&[8, 4], dtype::FLOAT32)),
598-
);
602+
let target_embed = crate::layers::UnifiedEmbedding::Regular(crate::layers::Embedding::new(
603+
ffi::zeros(&[8, 4], dtype::FLOAT32),
604+
));
599605
model.bind_target_embedding(target_embed);
600606

601607
assert!(

0 commit comments

Comments
 (0)