Skip to content

Commit 32aa7c6

Browse files
committed
fix(perf): prevent fp32 promotion in model hot paths (#20)
1 parent 9e55e2c commit 32aa7c6

27 files changed

Lines changed: 429 additions & 186 deletions

src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,8 @@ namespace {
12071207
auto one = array(1.0f);
12081208
auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2));
12091209
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
1210-
return {mlx::core::multiply(x, scale)};
1210+
auto result = mlx::core::multiply(x, scale);
1211+
return {mlx::core::astype(result, x.dtype())};
12111212
};
12121213
return mlx::core::compile(fn, true);
12131214
}
@@ -1233,7 +1234,8 @@ namespace {
12331234
auto one = array(1.0f);
12341235
auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2));
12351236
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
1236-
return {mlx::core::multiply(x, scale)};
1237+
auto result = mlx::core::multiply(x, scale);
1238+
return {mlx::core::astype(result, x.dtype())};
12371239
};
12381240
return mlx::core::compile(fn, true);
12391241
}
@@ -1259,7 +1261,8 @@ namespace {
12591261
auto erf_val = mlx::core::erf(mlx::core::divide(gate, sqrt2));
12601262
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
12611263
auto gelu_gate = mlx::core::multiply(gate, scale);
1262-
return {mlx::core::multiply(gelu_gate, x)};
1264+
auto result = mlx::core::multiply(gelu_gate, x);
1265+
return {mlx::core::astype(result, x.dtype())};
12631266
};
12641267
return mlx::core::compile(fn, true);
12651268
}
@@ -1281,7 +1284,8 @@ namespace {
12811284
auto fn = [](const std::vector<array>& inputs) -> std::vector<array> {
12821285
const auto& gate = inputs[0];
12831286
const auto& x = inputs[1];
1284-
return {mlx::core::multiply(gelu_tanh_approx(gate), x)};
1287+
auto result = mlx::core::multiply(gelu_tanh_approx(gate), x);
1288+
return {mlx::core::astype(result, x.dtype())};
12851289
};
12861290
return mlx::core::compile(fn, true);
12871291
}
@@ -1324,7 +1328,8 @@ namespace {
13241328
auto one = array(1.0f);
13251329
auto erf_val = mlx::core::erf(mlx::core::divide(zeroed, sqrt2));
13261330
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
1327-
return {mlx::core::multiply(zeroed, scale)};
1331+
auto result = mlx::core::multiply(zeroed, scale);
1332+
return {mlx::core::astype(result, x.dtype())};
13281333
};
13291334
return mlx::core::compile(fn, true);
13301335
}
@@ -1349,7 +1354,8 @@ namespace {
13491354
const auto& cap = inputs[1];
13501355
auto scaled = mlx::core::divide(scores, cap);
13511356
auto tanhed = mlx::core::tanh(scaled);
1352-
return {mlx::core::multiply(tanhed, cap)};
1357+
auto result = mlx::core::multiply(tanhed, cap);
1358+
return {mlx::core::astype(result, scores.dtype())};
13531359
};
13541360
return mlx::core::compile(fn, true);
13551361
}
@@ -1428,7 +1434,8 @@ namespace {
14281434
auto probs = mlx::core::softmax(scores, -1);
14291435

14301436
// probs @ V
1431-
return {mlx::core::matmul(probs, v)};
1437+
auto result = mlx::core::matmul(probs, v);
1438+
return {mlx::core::astype(result, v.dtype())};
14321439
};
14331440
return mlx::core::compile(fn, true); // shapeless=true
14341441
}
@@ -1450,7 +1457,8 @@ namespace {
14501457
scores = mlx::core::multiply(mlx::core::tanh(mlx::core::divide(scores, cap_arr)), cap_arr);
14511458
scores = mlx::core::add(scores, mask);
14521459
auto probs = mlx::core::softmax(scores, -1);
1453-
return {mlx::core::matmul(probs, v)};
1460+
auto result = mlx::core::matmul(probs, v);
1461+
return {mlx::core::astype(result, v.dtype())};
14541462
};
14551463
return mlx::core::compile(fn, true);
14561464
}
@@ -1541,7 +1549,7 @@ std::unique_ptr<MlxArray> compiled_softcap_sdpa_gqa(
15411549
auto probs = mlx::core::softmax(scores, -1);
15421550
auto v_grouped = mlx::core::reshape(v.inner, {B, Hk, 1, S, D});
15431551
auto ctx = mlx::core::matmul(probs, v_grouped);
1544-
auto out = mlx::core::reshape(ctx, {B, Hq, QL, D});
1552+
auto out = mlx::core::astype(mlx::core::reshape(ctx, {B, Hq, QL, D}), v.inner.dtype());
15451553
return std::make_unique<MlxArray>(std::move(out));
15461554
}
15471555

src/lib/mlxcel-core/src/ffi_tests.rs

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,63 @@ fn test_memory_functions() {
671671
set_wired_limit(0);
672672
}
673673

674+
#[test]
675+
fn test_scalar_helpers_preserve_bf16_and_f16_dtype() {
676+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
677+
let x = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype);
678+
679+
let multiplied = multiply_scalar(&x, 2.0);
680+
eval(&multiplied);
681+
assert_eq!(array_dtype(&multiplied), dtype);
682+
683+
let divided = divide_scalar(&x, 2.0);
684+
eval(&divided);
685+
assert_eq!(array_dtype(&divided), dtype);
686+
}
687+
}
688+
689+
#[test]
690+
fn test_softcap_helper_preserves_bf16_and_f16_dtype() {
691+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
692+
let x = astype(
693+
&from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]),
694+
dtype,
695+
);
696+
let out = crate::utils::softcap(&x, 30.0);
697+
eval(&out);
698+
699+
assert_eq!(array_shape(&out), vec![1, 5]);
700+
assert_eq!(array_dtype(&out), dtype);
701+
}
702+
}
703+
704+
#[test]
705+
fn test_attention_masks_intentionally_remain_float32() {
706+
let causal = crate::utils::create_causal_mask(2, 1);
707+
eval(&causal);
708+
assert_eq!(array_dtype(&causal), dtype::FLOAT32);
709+
710+
let windowed = crate::utils::create_causal_mask_with_window(4, 0, Some(2));
711+
eval(&windowed);
712+
assert_eq!(array_dtype(&windowed), dtype::FLOAT32);
713+
714+
let padded = crate::utils::create_padded_prefill_mask(2, 4, 0);
715+
eval(&padded);
716+
assert_eq!(array_dtype(&padded), dtype::FLOAT32);
717+
}
718+
719+
#[test]
720+
fn test_clip_residual_f16_widens_and_returns_f16() {
721+
let x = astype(&from_slice_f32(&[65500.0, 1.0], &[1, 2]), dtype::FLOAT16);
722+
let y = astype(&from_slice_f32(&[10.0, 2.0], &[1, 2]), dtype::FLOAT16);
723+
724+
let out = crate::utils::clip_residual_f16(&x, &y);
725+
eval(&out);
726+
727+
assert_eq!(array_shape(&out), vec![1, 2]);
728+
assert_eq!(array_dtype(&out), dtype::FLOAT16);
729+
}
730+
674731
#[test]
675732
fn test_compiled_gelu() {
676733
let x = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
@@ -702,6 +759,42 @@ fn test_compiled_gelu_approx() {
702759
assert!(item_f32(&total) > 0.0);
703760
}
704761

762+
#[test]
763+
fn test_compiled_gelu_preserves_bf16_and_f16_dtype() {
764+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
765+
let x = astype(&from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]), dtype);
766+
let out = compiled_gelu(&x);
767+
eval(&out);
768+
769+
assert_eq!(array_shape(&out), vec![1, 4]);
770+
assert_eq!(array_dtype(&out), dtype);
771+
}
772+
}
773+
774+
#[test]
775+
fn test_compiled_gelu_approx_preserves_bf16_and_f16_dtype() {
776+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
777+
let x = astype(&from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]), dtype);
778+
let out = compiled_gelu_approx(&x);
779+
eval(&out);
780+
781+
assert_eq!(array_shape(&out), vec![1, 4]);
782+
assert_eq!(array_dtype(&out), dtype);
783+
}
784+
}
785+
786+
#[test]
787+
fn test_compiled_gelu_topk_preserves_bf16_and_f16_dtype() {
788+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
789+
let x = astype(&from_slice_f32(&[-2.0, -1.0, 0.5, 4.0], &[1, 4]), dtype);
790+
let out = compiled_gelu_topk(&x, 1.0);
791+
eval(&out);
792+
793+
assert_eq!(array_shape(&out), vec![1, 4]);
794+
assert_eq!(array_dtype(&out), dtype);
795+
}
796+
}
797+
705798
#[test]
706799
fn test_gelu_approx_bf16_negative_values() {
707800
// Verify gelu_approx does not produce NaN for negative bf16 inputs.
@@ -763,6 +856,47 @@ fn test_compiled_geglu_activation() {
763856
assert!(item_f32(&total) > 0.0);
764857
}
765858

859+
#[test]
860+
fn test_compiled_geglu_preserves_bf16_and_f16_dtype() {
861+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
862+
let gate = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype);
863+
let x = astype(&from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]), dtype);
864+
let out = compiled_geglu_activation(&gate, &x);
865+
eval(&out);
866+
867+
assert_eq!(array_shape(&out), vec![1, 4]);
868+
assert_eq!(array_dtype(&out), dtype);
869+
}
870+
}
871+
872+
#[test]
873+
fn test_compiled_geglu_approx_preserves_bf16_and_f16_dtype() {
874+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
875+
let gate = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype);
876+
let x = astype(&from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]), dtype);
877+
let out = compiled_geglu_approx_activation(&gate, &x);
878+
eval(&out);
879+
880+
assert_eq!(array_shape(&out), vec![1, 4]);
881+
assert_eq!(array_dtype(&out), dtype);
882+
}
883+
}
884+
885+
#[test]
886+
fn test_gegelu_preserves_bf16_and_f16_dtype() {
887+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
888+
let x = astype(
889+
&from_slice_f32(&[-1.0, 0.5, 2.0, 3.0, -0.5, 1.0, 4.0, 5.0], &[1, 8]),
890+
dtype,
891+
);
892+
let out = crate::utils::gegelu(&x, 7.0);
893+
eval(&out);
894+
895+
assert_eq!(array_shape(&out), vec![1, 4]);
896+
assert_eq!(array_dtype(&out), dtype);
897+
}
898+
}
899+
766900
#[test]
767901
fn test_compiled_geglu_matches_manual() {
768902
// compiled_geglu_activation(gate, x) == gelu(gate) * x
@@ -827,6 +961,21 @@ fn test_compiled_softcap_zero_input() {
827961
assert!((item_f32(&out)).abs() < 1e-5, "softcap(0) should be 0");
828962
}
829963

964+
#[test]
965+
fn test_compiled_softcap_preserves_bf16_and_f16_dtype() {
966+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
967+
let scores = astype(
968+
&from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]),
969+
dtype,
970+
);
971+
let out = compiled_softcap(&scores, 30.0);
972+
eval(&out);
973+
974+
assert_eq!(array_shape(&out), vec![1, 5]);
975+
assert_eq!(array_dtype(&out), dtype);
976+
}
977+
}
978+
830979
#[test]
831980
fn test_compiled_clip_residual() {
832981
let x = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
@@ -863,6 +1012,21 @@ fn test_compiled_softcap_sdpa_shape() {
8631012
assert_eq!(array_shape(&out), vec![1, 2, 4, 8]);
8641013
}
8651014

1015+
#[test]
1016+
fn test_compiled_softcap_sdpa_preserves_v_dtype() {
1017+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
1018+
let q = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype);
1019+
let k = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype);
1020+
let v = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype);
1021+
1022+
let out = unsafe { compiled_softcap_sdpa(&q, &k, &v, 0.125, 30.0, std::ptr::null()) };
1023+
eval(&out);
1024+
1025+
assert_eq!(array_shape(&out), vec![1, 2, 4, 8]);
1026+
assert_eq!(array_dtype(&out), dtype);
1027+
}
1028+
}
1029+
8661030
#[test]
8671031
fn test_compiled_softcap_sdpa_gqa_shape() {
8681032
// Verify compiled_softcap_sdpa_gqa: Q has n_heads, K/V have n_kv_heads
@@ -878,6 +1042,22 @@ fn test_compiled_softcap_sdpa_gqa_shape() {
8781042
assert_eq!(array_shape(&out), vec![1, 4, 2, 8]);
8791043
}
8801044

1045+
#[test]
1046+
fn test_compiled_softcap_sdpa_gqa_preserves_v_dtype() {
1047+
for dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
1048+
let q = astype(&ones(&[1, 4, 2, 8], dtype::FLOAT32), dtype);
1049+
let k = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), dtype);
1050+
let v = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), dtype);
1051+
1052+
let out =
1053+
unsafe { compiled_softcap_sdpa_gqa(&q, &k, &v, 0.125, 30.0, 2, std::ptr::null()) };
1054+
eval(&out);
1055+
1056+
assert_eq!(array_shape(&out), vec![1, 4, 2, 8]);
1057+
assert_eq!(array_dtype(&out), dtype);
1058+
}
1059+
}
1060+
8811061
#[test]
8821062
fn test_unified_linear_quantized_weight_accessor() {
8831063
use crate::layers::{QuantizedWeight, UnifiedLinear};

src/lib/mlxcel-core/src/utils.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ pub fn create_causal_mask(size: i32, offset: i32) -> UniquePtr<MlxArray> {
100100

101101
// Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf
102102
// Use where_cond to avoid NaN from 0 * -inf
103+
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
104+
// added to attention scores, not propagated as model activations.
103105
let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32);
104106
let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32);
105107
let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask
@@ -211,6 +213,8 @@ pub fn create_causal_mask_with_window(
211213

212214
// Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf
213215
// Use where_cond to avoid NaN from 0 * -inf
216+
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
217+
// added to attention scores, not propagated as model activations.
214218
let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32);
215219
let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32);
216220
let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask
@@ -342,6 +346,7 @@ pub fn gelu_approx(x: &MlxArray) -> UniquePtr<MlxArray> {
342346
/// * `x` - Input array where last dim will be split into interleaved gelu/linear parts
343347
/// * `limit` - Clipping limit for numerical stability
344348
pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr<MlxArray> {
349+
let x_dtype = ffi::array_dtype(x);
345350
let shape = ffi::array_shape(x);
346351
let ndim = shape.len();
347352
let last_dim = shape[ndim - 1];
@@ -373,22 +378,27 @@ pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr<MlxArray> {
373378
let linear_part = ffi::squeeze_axis(&linear_part, ndim as i32);
374379

375380
// Clip both parts for numerical stability
376-
let neg_limit = ffi::full_f32(&[1], -limit, dtype::FLOAT32);
377-
let pos_limit = ffi::full_f32(&[1], limit, dtype::FLOAT32);
381+
let neg_limit = ffi::full_f32(&[1], -limit, x_dtype);
382+
let pos_limit = ffi::full_f32(&[1], limit, x_dtype);
378383

379384
let a_gelu = ffi::clip(&gelu_part, &neg_limit, &pos_limit);
380385
let a_linear = ffi::clip(&linear_part, &neg_limit, &pos_limit);
381386

382387
// Apply GELU approximation: x * sigmoid(1.702 * x)
383-
let coef = ffi::full_f32(&[1], 1.702, dtype::FLOAT32);
388+
let coef = ffi::full_f32(&[1], 1.702, x_dtype);
384389
let scaled = ffi::multiply(&coef, &a_gelu);
385390
let sigmoid_x = ffi::sigmoid(&scaled);
386391
let out_gelu = ffi::multiply(&a_gelu, &sigmoid_x);
387392

388393
// Compute: out_gelu * (a_linear + 1.0)
389-
let ones = ffi::full_f32(&[1], 1.0, dtype::FLOAT32);
394+
let ones = ffi::full_f32(&[1], 1.0, x_dtype);
390395
let linear_plus_one = ffi::add(&a_linear, &ones);
391-
ffi::multiply(&out_gelu, &linear_plus_one)
396+
let out = ffi::multiply(&out_gelu, &linear_plus_one);
397+
if ffi::array_dtype(&out) == x_dtype {
398+
out
399+
} else {
400+
ffi::astype(&out, x_dtype)
401+
}
392402
}
393403

394404
// Gemma-specific Functions.
@@ -431,7 +441,8 @@ pub fn clip_residual_f16(x: &MlxArray, y: &MlxArray) -> UniquePtr<MlxArray> {
431441
// float16 max is approximately 65504
432442
let bound = 65504.0f32;
433443

434-
// Cast to f32
444+
// Intentional FP32: the residual is widened only for overflow-safe clipping
445+
// and is cast back to f16 before returning.
435446
let x_f32 = ffi::astype(x, dtype::FLOAT32);
436447
let y_f32 = ffi::astype(y, dtype::FLOAT32);
437448

@@ -522,6 +533,8 @@ pub fn create_padded_prefill_mask(
522533
let combined = ffi::multiply(&causal, &valid_mask);
523534

524535
// Convert to additive mask: 1 → 0.0, 0 → -inf
536+
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
537+
// added to attention scores, not propagated as model activations.
525538
let zeros = ffi::zeros(&[padded_len, total_kv], dtype::FLOAT32);
526539
let neg_inf = ffi::full_f32(&[padded_len, total_kv], f32::NEG_INFINITY, dtype::FLOAT32);
527540
let bool_mask = ffi::greater(&combined, &zeros);

0 commit comments

Comments
 (0)