Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,8 @@ namespace {
auto one = array(1.0f);
auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2));
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
return {mlx::core::multiply(x, scale)};
auto result = mlx::core::multiply(x, scale);
return {mlx::core::astype(result, x.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand All @@ -1233,7 +1234,8 @@ namespace {
auto one = array(1.0f);
auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2));
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
return {mlx::core::multiply(x, scale)};
auto result = mlx::core::multiply(x, scale);
return {mlx::core::astype(result, x.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand All @@ -1259,7 +1261,8 @@ namespace {
auto erf_val = mlx::core::erf(mlx::core::divide(gate, sqrt2));
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
auto gelu_gate = mlx::core::multiply(gate, scale);
return {mlx::core::multiply(gelu_gate, x)};
auto result = mlx::core::multiply(gelu_gate, x);
return {mlx::core::astype(result, x.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand All @@ -1281,7 +1284,8 @@ namespace {
auto fn = [](const std::vector<array>& inputs) -> std::vector<array> {
const auto& gate = inputs[0];
const auto& x = inputs[1];
return {mlx::core::multiply(gelu_tanh_approx(gate), x)};
auto result = mlx::core::multiply(gelu_tanh_approx(gate), x);
return {mlx::core::astype(result, x.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand Down Expand Up @@ -1324,7 +1328,8 @@ namespace {
auto one = array(1.0f);
auto erf_val = mlx::core::erf(mlx::core::divide(zeroed, sqrt2));
auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val));
return {mlx::core::multiply(zeroed, scale)};
auto result = mlx::core::multiply(zeroed, scale);
return {mlx::core::astype(result, x.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand All @@ -1349,7 +1354,8 @@ namespace {
const auto& cap = inputs[1];
auto scaled = mlx::core::divide(scores, cap);
auto tanhed = mlx::core::tanh(scaled);
return {mlx::core::multiply(tanhed, cap)};
auto result = mlx::core::multiply(tanhed, cap);
return {mlx::core::astype(result, scores.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand Down Expand Up @@ -1428,7 +1434,8 @@ namespace {
auto probs = mlx::core::softmax(scores, -1);

// probs @ V
return {mlx::core::matmul(probs, v)};
auto result = mlx::core::matmul(probs, v);
return {mlx::core::astype(result, v.dtype())};
};
return mlx::core::compile(fn, true); // shapeless=true
}
Expand All @@ -1450,7 +1457,8 @@ namespace {
scores = mlx::core::multiply(mlx::core::tanh(mlx::core::divide(scores, cap_arr)), cap_arr);
scores = mlx::core::add(scores, mask);
auto probs = mlx::core::softmax(scores, -1);
return {mlx::core::matmul(probs, v)};
auto result = mlx::core::matmul(probs, v);
return {mlx::core::astype(result, v.dtype())};
};
return mlx::core::compile(fn, true);
}
Expand Down Expand Up @@ -1541,7 +1549,7 @@ std::unique_ptr<MlxArray> compiled_softcap_sdpa_gqa(
auto probs = mlx::core::softmax(scores, -1);
auto v_grouped = mlx::core::reshape(v.inner, {B, Hk, 1, S, D});
auto ctx = mlx::core::matmul(probs, v_grouped);
auto out = mlx::core::reshape(ctx, {B, Hq, QL, D});
auto out = mlx::core::astype(mlx::core::reshape(ctx, {B, Hq, QL, D}), v.inner.dtype());
return std::make_unique<MlxArray>(std::move(out));
}

Expand Down
180 changes: 180 additions & 0 deletions src/lib/mlxcel-core/src/ffi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,63 @@ fn test_memory_functions() {
set_wired_limit(0);
}

#[test]
fn test_scalar_helpers_preserve_bf16_and_f16_dtype() {
    // Scalar multiply/divide must return the same half-precision dtype they were given.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let base_f32 = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let x = astype(&base_f32, half_dtype);

        // x * 2.0
        let product = multiply_scalar(&x, 2.0);
        eval(&product);
        assert_eq!(array_dtype(&product), half_dtype);

        // x / 2.0
        let quotient = divide_scalar(&x, 2.0);
        eval(&quotient);
        assert_eq!(array_dtype(&quotient), half_dtype);
    }
}

#[test]
fn test_softcap_helper_preserves_bf16_and_f16_dtype() {
    // The `utils::softcap` helper should keep both shape and dtype for bf16 and f16 inputs.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let scores_f32 = from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]);
        let scores = astype(&scores_f32, half_dtype);

        let capped = crate::utils::softcap(&scores, 30.0);
        eval(&capped);

        assert_eq!(array_shape(&capped), vec![1, 5]);
        assert_eq!(array_dtype(&capped), half_dtype);
    }
}

#[test]
fn test_attention_masks_intentionally_remain_float32() {
    // Every mask builder below is expected to produce a FLOAT32 array.

    // Plain causal mask.
    let causal_mask = crate::utils::create_causal_mask(2, 1);
    eval(&causal_mask);
    assert_eq!(array_dtype(&causal_mask), dtype::FLOAT32);

    // Causal mask with a window of 2.
    let windowed_mask = crate::utils::create_causal_mask_with_window(4, 0, Some(2));
    eval(&windowed_mask);
    assert_eq!(array_dtype(&windowed_mask), dtype::FLOAT32);

    // Padded prefill mask.
    let padded_mask = crate::utils::create_padded_prefill_mask(2, 4, 0);
    eval(&padded_mask);
    assert_eq!(array_dtype(&padded_mask), dtype::FLOAT32);
}

#[test]
fn test_clip_residual_f16_widens_and_returns_f16() {
    // Two f16 operands go in; the result must come back as f16 with the same shape.
    // NOTE(review): only shape and dtype are asserted here; the output values
    // themselves are not checked.
    let lhs_f32 = from_slice_f32(&[65500.0, 1.0], &[1, 2]);
    let rhs_f32 = from_slice_f32(&[10.0, 2.0], &[1, 2]);
    let x = astype(&lhs_f32, dtype::FLOAT16);
    let y = astype(&rhs_f32, dtype::FLOAT16);

    let clipped = crate::utils::clip_residual_f16(&x, &y);
    eval(&clipped);

    assert_eq!(array_shape(&clipped), vec![1, 2]);
    assert_eq!(array_dtype(&clipped), dtype::FLOAT16);
}

#[test]
fn test_compiled_gelu() {
let x = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
Expand Down Expand Up @@ -702,6 +759,42 @@ fn test_compiled_gelu_approx() {
assert!(item_f32(&total) > 0.0);
}

#[test]
fn test_compiled_gelu_preserves_bf16_and_f16_dtype() {
    // compiled_gelu must not change the shape or dtype of half-precision inputs.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let input_f32 = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
        let input = astype(&input_f32, half_dtype);

        let activated = compiled_gelu(&input);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half_dtype);
    }
}

#[test]
fn test_compiled_gelu_approx_preserves_bf16_and_f16_dtype() {
    // compiled_gelu_approx must not change the shape or dtype of half-precision inputs.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let input_f32 = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
        let input = astype(&input_f32, half_dtype);

        let activated = compiled_gelu_approx(&input);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half_dtype);
    }
}

#[test]
fn test_compiled_gelu_topk_preserves_bf16_and_f16_dtype() {
    // compiled_gelu_topk must not change the shape or dtype of half-precision inputs.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let input_f32 = from_slice_f32(&[-2.0, -1.0, 0.5, 4.0], &[1, 4]);
        let input = astype(&input_f32, half_dtype);

        let activated = compiled_gelu_topk(&input, 1.0);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half_dtype);
    }
}

#[test]
fn test_gelu_approx_bf16_negative_values() {
// Verify gelu_approx does not produce NaN for negative bf16 inputs.
Expand Down Expand Up @@ -763,6 +856,47 @@ fn test_compiled_geglu_activation() {
assert!(item_f32(&total) > 0.0);
}

#[test]
fn test_compiled_geglu_preserves_bf16_and_f16_dtype() {
    // With gate and x sharing a half-precision dtype, the output must keep it.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let gate_f32 = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let x_f32 = from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]);
        let gate = astype(&gate_f32, half_dtype);
        let x = astype(&x_f32, half_dtype);

        let gated = compiled_geglu_activation(&gate, &x);
        eval(&gated);

        assert_eq!(array_shape(&gated), vec![1, 4]);
        assert_eq!(array_dtype(&gated), half_dtype);
    }
}

#[test]
fn test_compiled_geglu_approx_preserves_bf16_and_f16_dtype() {
    // With gate and x sharing a half-precision dtype, the output must keep it.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let gate_f32 = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let x_f32 = from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]);
        let gate = astype(&gate_f32, half_dtype);
        let x = astype(&x_f32, half_dtype);

        let gated = compiled_geglu_approx_activation(&gate, &x);
        eval(&gated);

        assert_eq!(array_shape(&gated), vec![1, 4]);
        assert_eq!(array_dtype(&gated), half_dtype);
    }
}

#[test]
fn test_gegelu_preserves_bf16_and_f16_dtype() {
    // A [1, 8] input is expected to yield a [1, 4] output in the same dtype.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let input_f32 = from_slice_f32(&[-1.0, 0.5, 2.0, 3.0, -0.5, 1.0, 4.0, 5.0], &[1, 8]);
        let input = astype(&input_f32, half_dtype);

        let activated = crate::utils::gegelu(&input, 7.0);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half_dtype);
    }
}

#[test]
fn test_compiled_geglu_matches_manual() {
// compiled_geglu_activation(gate, x) == gelu(gate) * x
Expand Down Expand Up @@ -827,6 +961,21 @@ fn test_compiled_softcap_zero_input() {
assert!((item_f32(&out)).abs() < 1e-5, "softcap(0) should be 0");
}

#[test]
fn test_compiled_softcap_preserves_bf16_and_f16_dtype() {
    // compiled_softcap should keep both shape and dtype for bf16 and f16 scores.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let scores_f32 = from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]);
        let scores = astype(&scores_f32, half_dtype);

        let capped = compiled_softcap(&scores, 30.0);
        eval(&capped);

        assert_eq!(array_shape(&capped), vec![1, 5]);
        assert_eq!(array_dtype(&capped), half_dtype);
    }
}

#[test]
fn test_compiled_clip_residual() {
let x = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
Expand Down Expand Up @@ -863,6 +1012,21 @@ fn test_compiled_softcap_sdpa_shape() {
assert_eq!(array_shape(&out), vec![1, 2, 4, 8]);
}

#[test]
fn test_compiled_softcap_sdpa_preserves_v_dtype() {
    // Q, K and V are all-ones tensors of identical shape, cast to the dtype under test.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let q_f32 = ones(&[1, 2, 4, 8], dtype::FLOAT32);
        let k_f32 = ones(&[1, 2, 4, 8], dtype::FLOAT32);
        let v_f32 = ones(&[1, 2, 4, 8], dtype::FLOAT32);
        let q = astype(&q_f32, half_dtype);
        let k = astype(&k_f32, half_dtype);
        let v = astype(&v_f32, half_dtype);

        // NOTE(review): a null pointer is passed as the last argument, presumably
        // meaning "no mask" — confirm against the FFI signature and record the
        // soundness argument here as a SAFETY comment.
        let attended = unsafe { compiled_softcap_sdpa(&q, &k, &v, 0.125, 30.0, std::ptr::null()) };
        eval(&attended);

        assert_eq!(array_shape(&attended), vec![1, 2, 4, 8]);
        assert_eq!(array_dtype(&attended), half_dtype);
    }
}

#[test]
fn test_compiled_softcap_sdpa_gqa_shape() {
// Verify compiled_softcap_sdpa_gqa: Q has n_heads, K/V have n_kv_heads
Expand All @@ -878,6 +1042,22 @@ fn test_compiled_softcap_sdpa_gqa_shape() {
assert_eq!(array_shape(&out), vec![1, 4, 2, 8]);
}

#[test]
fn test_compiled_softcap_sdpa_gqa_preserves_v_dtype() {
    // Q is [1, 4, 2, 8] while K and V are [1, 2, 2, 8]; the output must have
    // Q's shape and the dtype under test.
    for half_dtype in [dtype::BFLOAT16, dtype::FLOAT16] {
        let q_f32 = ones(&[1, 4, 2, 8], dtype::FLOAT32);
        let k_f32 = ones(&[1, 2, 2, 8], dtype::FLOAT32);
        let v_f32 = ones(&[1, 2, 2, 8], dtype::FLOAT32);
        let q = astype(&q_f32, half_dtype);
        let k = astype(&k_f32, half_dtype);
        let v = astype(&v_f32, half_dtype);

        // NOTE(review): a null pointer is passed as the last argument, presumably
        // meaning "no mask" — confirm against the FFI signature and record the
        // soundness argument here as a SAFETY comment.
        let attended =
            unsafe { compiled_softcap_sdpa_gqa(&q, &k, &v, 0.125, 30.0, 2, std::ptr::null()) };
        eval(&attended);

        assert_eq!(array_shape(&attended), vec![1, 4, 2, 8]);
        assert_eq!(array_dtype(&attended), half_dtype);
    }
}

#[test]
fn test_unified_linear_quantized_weight_accessor() {
use crate::layers::{QuantizedWeight, UnifiedLinear};
Expand Down
25 changes: 19 additions & 6 deletions src/lib/mlxcel-core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ pub fn create_causal_mask(size: i32, offset: i32) -> UniquePtr<MlxArray> {

// Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf
// Use where_cond to avoid NaN from 0 * -inf
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
// added to attention scores, not propagated as model activations.
let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32);
let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32);
let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask
Expand Down Expand Up @@ -211,6 +213,8 @@ pub fn create_causal_mask_with_window(

// Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf
// Use where_cond to avoid NaN from 0 * -inf
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
// added to attention scores, not propagated as model activations.
let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32);
let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32);
let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask
Expand Down Expand Up @@ -342,6 +346,7 @@ pub fn gelu_approx(x: &MlxArray) -> UniquePtr<MlxArray> {
/// * `x` - Input array where last dim will be split into interleaved gelu/linear parts
/// * `limit` - Clipping limit for numerical stability
pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr<MlxArray> {
let x_dtype = ffi::array_dtype(x);
let shape = ffi::array_shape(x);
let ndim = shape.len();
let last_dim = shape[ndim - 1];
Expand Down Expand Up @@ -373,22 +378,27 @@ pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr<MlxArray> {
let linear_part = ffi::squeeze_axis(&linear_part, ndim as i32);

// Clip both parts for numerical stability
let neg_limit = ffi::full_f32(&[1], -limit, dtype::FLOAT32);
let pos_limit = ffi::full_f32(&[1], limit, dtype::FLOAT32);
let neg_limit = ffi::full_f32(&[1], -limit, x_dtype);
let pos_limit = ffi::full_f32(&[1], limit, x_dtype);

let a_gelu = ffi::clip(&gelu_part, &neg_limit, &pos_limit);
let a_linear = ffi::clip(&linear_part, &neg_limit, &pos_limit);

// Apply GELU approximation: x * sigmoid(1.702 * x)
let coef = ffi::full_f32(&[1], 1.702, dtype::FLOAT32);
let coef = ffi::full_f32(&[1], 1.702, x_dtype);
let scaled = ffi::multiply(&coef, &a_gelu);
let sigmoid_x = ffi::sigmoid(&scaled);
let out_gelu = ffi::multiply(&a_gelu, &sigmoid_x);

// Compute: out_gelu * (a_linear + 1.0)
let ones = ffi::full_f32(&[1], 1.0, dtype::FLOAT32);
let ones = ffi::full_f32(&[1], 1.0, x_dtype);
let linear_plus_one = ffi::add(&a_linear, &ones);
ffi::multiply(&out_gelu, &linear_plus_one)
let out = ffi::multiply(&out_gelu, &linear_plus_one);
if ffi::array_dtype(&out) == x_dtype {
out
} else {
ffi::astype(&out, x_dtype)
}
}

// Gemma-specific Functions.
Expand Down Expand Up @@ -431,7 +441,8 @@ pub fn clip_residual_f16(x: &MlxArray, y: &MlxArray) -> UniquePtr<MlxArray> {
// float16 max is approximately 65504
let bound = 65504.0f32;

// Cast to f32
// Intentional FP32: the residual is widened only for overflow-safe clipping
// and is cast back to f16 before returning.
let x_f32 = ffi::astype(x, dtype::FLOAT32);
let y_f32 = ffi::astype(y, dtype::FLOAT32);

Expand Down Expand Up @@ -522,6 +533,8 @@ pub fn create_padded_prefill_mask(
let combined = ffi::multiply(&causal, &valid_mask);

// Convert to additive mask: 1 → 0.0, 0 → -inf
// Intentional FP32: additive attention masks carry 0/-inf sentinels and are
// added to attention scores, not propagated as model activations.
let zeros = ffi::zeros(&[padded_len, total_kv], dtype::FLOAT32);
let neg_inf = ffi::full_f32(&[padded_len, total_kv], f32::NEG_INFINITY, dtype::FLOAT32);
let bool_mask = ffi::greater(&combined, &zeros);
Expand Down
Loading