diff --git a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp index c973a3c..663da8d 100644 --- a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp +++ b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp @@ -1207,7 +1207,8 @@ namespace { auto one = array(1.0f); auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2)); auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val)); - return {mlx::core::multiply(x, scale)}; + auto result = mlx::core::multiply(x, scale); + return {mlx::core::astype(result, x.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1233,7 +1234,8 @@ namespace { auto one = array(1.0f); auto erf_val = mlx::core::erf(mlx::core::divide(x, sqrt2)); auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val)); - return {mlx::core::multiply(x, scale)}; + auto result = mlx::core::multiply(x, scale); + return {mlx::core::astype(result, x.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1259,7 +1261,8 @@ namespace { auto erf_val = mlx::core::erf(mlx::core::divide(gate, sqrt2)); auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val)); auto gelu_gate = mlx::core::multiply(gate, scale); - return {mlx::core::multiply(gelu_gate, x)}; + auto result = mlx::core::multiply(gelu_gate, x); + return {mlx::core::astype(result, x.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1281,7 +1284,8 @@ namespace { auto fn = [](const std::vector& inputs) -> std::vector { const auto& gate = inputs[0]; const auto& x = inputs[1]; - return {mlx::core::multiply(gelu_tanh_approx(gate), x)}; + auto result = mlx::core::multiply(gelu_tanh_approx(gate), x); + return {mlx::core::astype(result, x.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1324,7 +1328,8 @@ namespace { auto one = array(1.0f); auto erf_val = mlx::core::erf(mlx::core::divide(zeroed, sqrt2)); auto scale = mlx::core::multiply(half, mlx::core::add(one, erf_val)); - return {mlx::core::multiply(zeroed, scale)}; + auto result = mlx::core::multiply(zeroed, scale); + return {mlx::core::astype(result, x.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1349,7 +1354,8 @@ namespace { const auto& cap = inputs[1]; auto scaled = mlx::core::divide(scores, cap); auto tanhed = mlx::core::tanh(scaled); - return {mlx::core::multiply(tanhed, cap)}; + auto result = mlx::core::multiply(tanhed, cap); + return {mlx::core::astype(result, scores.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1428,7 +1434,8 @@ namespace { auto probs = mlx::core::softmax(scores, -1); // probs @ V - return {mlx::core::matmul(probs, v)}; + auto result = mlx::core::matmul(probs, v); + return {mlx::core::astype(result, v.dtype())}; }; return mlx::core::compile(fn, true); // shapeless=true } @@ -1450,7 +1457,8 @@ namespace { scores = mlx::core::multiply(mlx::core::tanh(mlx::core::divide(scores, cap_arr)), cap_arr); scores = mlx::core::add(scores, mask); auto probs = mlx::core::softmax(scores, -1); - return {mlx::core::matmul(probs, v)}; + auto result = mlx::core::matmul(probs, v); + return {mlx::core::astype(result, v.dtype())}; }; return mlx::core::compile(fn, true); } @@ -1541,7 +1549,7 @@ std::unique_ptr compiled_softcap_sdpa_gqa( auto probs = mlx::core::softmax(scores, -1); auto v_grouped = mlx::core::reshape(v.inner, {B, Hk, 1, S, D}); auto ctx = mlx::core::matmul(probs, v_grouped); - auto out = mlx::core::reshape(ctx, {B, Hq, QL, D}); + auto out = mlx::core::astype(mlx::core::reshape(ctx, {B, Hq, QL, D}), v.inner.dtype()); return std::make_unique(std::move(out)); } diff --git a/src/lib/mlxcel-core/src/ffi_tests.rs b/src/lib/mlxcel-core/src/ffi_tests.rs index de558ee..a59153a 100644 --- a/src/lib/mlxcel-core/src/ffi_tests.rs +++ b/src/lib/mlxcel-core/src/ffi_tests.rs @@ -671,6 +671,63 @@ fn test_memory_functions() { set_wired_limit(0); } +#[test] +fn test_scalar_helpers_preserve_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype); + + let multiplied = multiply_scalar(&x, 2.0); + eval(&multiplied); + assert_eq!(array_dtype(&multiplied), dtype); + + let divided = divide_scalar(&x, 2.0); + eval(÷d); + assert_eq!(array_dtype(÷d), dtype); + } +} + +#[test] +fn test_softcap_helper_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype( + &from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]), + dtype, + ); + let out = crate::utils::softcap(&x, 30.0); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 5]); + assert_eq!(array_dtype(&out), dtype); + } +} + +#[test] +fn test_attention_masks_intentionally_remain_float32() { + let causal = crate::utils::create_causal_mask(2, 1); + eval(&causal); + assert_eq!(array_dtype(&causal), dtype::FLOAT32); + + let windowed = crate::utils::create_causal_mask_with_window(4, 0, Some(2)); + eval(&windowed); + assert_eq!(array_dtype(&windowed), dtype::FLOAT32); + + let padded = crate::utils::create_padded_prefill_mask(2, 4, 0); + eval(&padded); + assert_eq!(array_dtype(&padded), dtype::FLOAT32); +} + +#[test] +fn test_clip_residual_f16_widens_and_returns_f16() { + let x = astype(&from_slice_f32(&[65500.0, 1.0], &[1, 2]), dtype::FLOAT16); + let y = astype(&from_slice_f32(&[10.0, 2.0], &[1, 2]), dtype::FLOAT16); + + let out = crate::utils::clip_residual_f16(&x, &y); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 2]); + assert_eq!(array_dtype(&out), dtype::FLOAT16); +} + #[test] fn test_compiled_gelu() { let x = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]); @@ -702,6 +759,42 @@ fn test_compiled_gelu_approx() { assert!(item_f32(&total) > 0.0); } +#[test] +fn test_compiled_gelu_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype(&from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]), dtype); + let out = compiled_gelu(&x); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + +#[test] +fn test_compiled_gelu_approx_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype(&from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]), dtype); + let out = compiled_gelu_approx(&x); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + +#[test] +fn test_compiled_gelu_topk_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype(&from_slice_f32(&[-2.0, -1.0, 0.5, 4.0], &[1, 4]), dtype); + let out = compiled_gelu_topk(&x, 1.0); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + #[test] fn test_gelu_approx_bf16_negative_values() { // Verify gelu_approx does not produce NaN for negative bf16 inputs. @@ -763,6 +856,47 @@ fn test_compiled_geglu_activation() { assert!(item_f32(&total) > 0.0); } +#[test] +fn test_compiled_geglu_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let gate = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype); + let x = astype(&from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]), dtype); + let out = compiled_geglu_activation(&gate, &x); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + +#[test] +fn test_compiled_geglu_approx_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let gate = astype(&from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), dtype); + let x = astype(&from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]), dtype); + let out = compiled_geglu_approx_activation(&gate, &x); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + +#[test] +fn test_gegelu_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let x = astype( + &from_slice_f32(&[-1.0, 0.5, 2.0, 3.0, -0.5, 1.0, 4.0, 5.0], &[1, 8]), + dtype, + ); + let out = crate::utils::gegelu(&x, 7.0); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype); + } +} + #[test] fn test_compiled_geglu_matches_manual() { // compiled_geglu_activation(gate, x) == gelu(gate) * x @@ -827,6 +961,21 @@ fn test_compiled_softcap_zero_input() { assert!((item_f32(&out)).abs() < 1e-5, "softcap(0) should be 0"); } +#[test] +fn test_compiled_softcap_preserves_bf16_and_f16_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let scores = astype( + &from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]), + dtype, + ); + let out = compiled_softcap(&scores, 30.0); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 5]); + assert_eq!(array_dtype(&out), dtype); + } +} + #[test] fn test_compiled_clip_residual() { let x = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]); @@ -863,6 +1012,21 @@ fn test_compiled_softcap_sdpa_shape() { assert_eq!(array_shape(&out), vec![1, 2, 4, 8]); } +#[test] +fn test_compiled_softcap_sdpa_preserves_v_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let q = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype); + let k = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype); + let v = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), dtype); + + let out = unsafe { compiled_softcap_sdpa(&q, &k, &v, 0.125, 30.0, std::ptr::null()) }; + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 2, 4, 8]); + assert_eq!(array_dtype(&out), dtype); + } +} + #[test] fn test_compiled_softcap_sdpa_gqa_shape() { // Verify compiled_softcap_sdpa_gqa: Q has n_heads, K/V have n_kv_heads @@ -878,6 +1042,22 @@ fn test_compiled_softcap_sdpa_gqa_shape() { assert_eq!(array_shape(&out), vec![1, 4, 2, 8]); } +#[test] +fn test_compiled_softcap_sdpa_gqa_preserves_v_dtype() { + for dtype in [dtype::BFLOAT16, dtype::FLOAT16] { + let q = astype(&ones(&[1, 4, 2, 8], dtype::FLOAT32), dtype); + let k = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), dtype); + let v = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), dtype); + + let out = + unsafe { compiled_softcap_sdpa_gqa(&q, &k, &v, 0.125, 30.0, 2, std::ptr::null()) }; + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4, 2, 8]); + assert_eq!(array_dtype(&out), dtype); + } +} + #[test] fn test_unified_linear_quantized_weight_accessor() { use crate::layers::{QuantizedWeight, UnifiedLinear}; diff --git a/src/lib/mlxcel-core/src/utils.rs b/src/lib/mlxcel-core/src/utils.rs index 450ce21..f14c8de 100644 --- a/src/lib/mlxcel-core/src/utils.rs +++ b/src/lib/mlxcel-core/src/utils.rs @@ -100,6 +100,8 @@ pub fn create_causal_mask(size: i32, offset: i32) -> UniquePtr { // Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf // Use where_cond to avoid NaN from 0 * -inf + // Intentional FP32: additive attention masks carry 0/-inf sentinels and are + // added to attention scores, not propagated as model activations. let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32); let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32); let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask @@ -211,6 +213,8 @@ pub fn create_causal_mask_with_window( // Convert to attention mask format: where mask=1 -> 0, where mask=0 -> -inf // Use where_cond to avoid NaN from 0 * -inf + // Intentional FP32: additive attention masks carry 0/-inf sentinels and are + // added to attention scores, not propagated as model activations. let zeros = ffi::zeros(&[size, total_len], dtype::FLOAT32); let neg_inf = ffi::full_f32(&[size, total_len], f32::NEG_INFINITY, dtype::FLOAT32); let bool_mask = ffi::greater(&mask, &zeros); // mask > 0 gives bool mask @@ -342,6 +346,7 @@ pub fn gelu_approx(x: &MlxArray) -> UniquePtr { /// * `x` - Input array where last dim will be split into interleaved gelu/linear parts /// * `limit` - Clipping limit for numerical stability pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr { + let x_dtype = ffi::array_dtype(x); let shape = ffi::array_shape(x); let ndim = shape.len(); let last_dim = shape[ndim - 1]; @@ -373,22 +378,27 @@ pub fn gegelu(x: &MlxArray, limit: f32) -> UniquePtr { let linear_part = ffi::squeeze_axis(&linear_part, ndim as i32); // Clip both parts for numerical stability - let neg_limit = ffi::full_f32(&[1], -limit, dtype::FLOAT32); - let pos_limit = ffi::full_f32(&[1], limit, dtype::FLOAT32); + let neg_limit = ffi::full_f32(&[1], -limit, x_dtype); + let pos_limit = ffi::full_f32(&[1], limit, x_dtype); let a_gelu = ffi::clip(&gelu_part, &neg_limit, &pos_limit); let a_linear = ffi::clip(&linear_part, &neg_limit, &pos_limit); // Apply GELU approximation: x * sigmoid(1.702 * x) - let coef = ffi::full_f32(&[1], 1.702, dtype::FLOAT32); + let coef = ffi::full_f32(&[1], 1.702, x_dtype); let scaled = ffi::multiply(&coef, &a_gelu); let sigmoid_x = ffi::sigmoid(&scaled); let out_gelu = ffi::multiply(&a_gelu, &sigmoid_x); // Compute: out_gelu * (a_linear + 1.0) - let ones = ffi::full_f32(&[1], 1.0, dtype::FLOAT32); + let ones = ffi::full_f32(&[1], 1.0, x_dtype); let linear_plus_one = ffi::add(&a_linear, &ones); - ffi::multiply(&out_gelu, &linear_plus_one) + let out = ffi::multiply(&out_gelu, &linear_plus_one); + if ffi::array_dtype(&out) == x_dtype { + out + } else { + ffi::astype(&out, x_dtype) + } } // Gemma-specific Functions. @@ -431,7 +441,8 @@ pub fn clip_residual_f16(x: &MlxArray, y: &MlxArray) -> UniquePtr { // float16 max is approximately 65504 let bound = 65504.0f32; - // Cast to f32 + // Intentional FP32: the residual is widened only for overflow-safe clipping + // and is cast back to f16 before returning. let x_f32 = ffi::astype(x, dtype::FLOAT32); let y_f32 = ffi::astype(y, dtype::FLOAT32); @@ -522,6 +533,8 @@ pub fn create_padded_prefill_mask( let combined = ffi::multiply(&causal, &valid_mask); // Convert to additive mask: 1 → 0.0, 0 → -inf + // Intentional FP32: additive attention masks carry 0/-inf sentinels and are + // added to attention scores, not propagated as model activations. let zeros = ffi::zeros(&[padded_len, total_kv], dtype::FLOAT32); let neg_inf = ffi::full_f32(&[padded_len, total_kv], f32::NEG_INFINITY, dtype::FLOAT32); let bool_mask = ffi::greater(&combined, &zeros); diff --git a/src/models/deepseek.rs b/src/models/deepseek.rs index a6e994e..3501796 100644 --- a/src/models/deepseek.rs +++ b/src/models/deepseek.rs @@ -381,13 +381,11 @@ impl MoE { // Apply experts let expert_out = self.experts.forward(&x_flat, &indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let y = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let y = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared experts if present let result = if let Some(shared) = &self.shared_experts { diff --git a/src/models/deepseek_v3.rs b/src/models/deepseek_v3.rs index 73d1fde..91bf680 100644 --- a/src/models/deepseek_v3.rs +++ b/src/models/deepseek_v3.rs @@ -689,13 +689,11 @@ impl MoEBlock { // Expert computation let y = self.switch_mlp.forward(x, &indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - y.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &y, + &scores, + mlxcel_core::array_dtype(x), + ); // Add shared experts if present if let Some(ref shared) = self.shared_experts { diff --git a/src/models/deepseek_v32.rs b/src/models/deepseek_v32.rs index d1582d1..81a5e74 100644 --- a/src/models/deepseek_v32.rs +++ b/src/models/deepseek_v32.rs @@ -532,13 +532,11 @@ impl MoEBlock { let (indices, scores) = self.gate.forward(x); let y = self.experts.forward(x, &indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - y.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &y, + &scores, + mlxcel_core::array_dtype(x), + ); if let Some(ref shared) = self.shared_experts { let shared_out = shared.forward(x); diff --git a/src/models/ernie4_5_moe.rs b/src/models/ernie4_5_moe.rs index b745126..905cefc 100644 --- a/src/models/ernie4_5_moe.rs +++ b/src/models/ernie4_5_moe.rs @@ -511,20 +511,18 @@ impl MoEBlock { // Normalize scores let score_sum = mlxcel_core::sum_axis(&scores, -1, true); - let score_sum = - mlxcel_core::maximum(&score_sum, &mlxcel_core::from_slice_f32(&[1e-12], &[1])); + let eps = mlxcel_core::full_f32(&[1], 1e-12, mlxcel_core::array_dtype(&score_sum)); + let score_sum = mlxcel_core::maximum(&score_sum, &eps); let scores = mlxcel_core::divide(&scores, &score_sum); // Apply routed experts let expert_out = self.switch_mlp.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared experts if present if let Some(ref shared) = self.shared_experts { diff --git a/src/models/exaone_moe.rs b/src/models/exaone_moe.rs index e113151..ca632a0 100644 --- a/src/models/exaone_moe.rs +++ b/src/models/exaone_moe.rs @@ -419,13 +419,11 @@ impl ExaoneMoE { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.switch_mlp.forward(&x_flat, &indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut output = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut output = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared experts output if present if let Some(ref shared) = self.shared_experts { diff --git a/src/models/glm4_moe.rs b/src/models/glm4_moe.rs index 5163076..b214f9c 100644 --- a/src/models/glm4_moe.rs +++ b/src/models/glm4_moe.rs @@ -736,13 +736,11 @@ impl Glm4Moe { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - topk_scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &topk_scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared expert if present if let Some(ref shared) = self.shared_expert { diff --git a/src/models/glm4_moe_lite.rs b/src/models/glm4_moe_lite.rs index 27da3df..b64c86b 100644 --- a/src/models/glm4_moe_lite.rs +++ b/src/models/glm4_moe_lite.rs @@ -661,13 +661,11 @@ impl MoELayer { // Apply experts let expert_out = self.switch_mlp.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - topk_scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &topk_scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared expert if let Some(ref shared) = self.shared_experts { diff --git a/src/models/gpt_oss.rs b/src/models/gpt_oss.rs index a6c76b1..12aeeac 100644 --- a/src/models/gpt_oss.rs +++ b/src/models/gpt_oss.rs @@ -409,20 +409,20 @@ fn gpt_oss_swiglu(x_linear: &MlxArray, x_glu: &MlxArray, limit: f32) -> UniquePt let alpha = 1.702f32; // Clamp values - let neg_limit = mlxcel_core::from_slice_f32(&[-limit], &[1]); - let pos_limit = mlxcel_core::from_slice_f32(&[limit], &[1]); + let neg_limit = mlxcel_core::full_f32(&[1], -limit, input_dtype); + let pos_limit = mlxcel_core::full_f32(&[1], limit, input_dtype); let x_glu = mlxcel_core::minimum(x_glu, &pos_limit); let x_linear = mlxcel_core::maximum(x_linear, &neg_limit); let x_linear = mlxcel_core::minimum(&x_linear, &pos_limit); // glu_scaled = alpha * x_glu -> sigmoid -> out_glu = x_glu * sig - let alpha_arr = mlxcel_core::from_slice_f32(&[alpha], &[1]); + let alpha_arr = mlxcel_core::full_f32(&[1], alpha, input_dtype); let glu_scaled = mlxcel_core::multiply(&alpha_arr, &x_glu); let sig = mlxcel_core::sigmoid(&glu_scaled); let out_glu = mlxcel_core::multiply(&x_glu, &sig); // (x_linear + 1) * out_glu - let one = mlxcel_core::from_slice_f32(&[1.0], &[1]); + let one = mlxcel_core::full_f32(&[1], 1.0, input_dtype); let x_linear_plus_1 = mlxcel_core::add(&x_linear, &one); let result = mlxcel_core::multiply(&out_glu, &x_linear_plus_1); mlxcel_core::astype(&result, input_dtype) @@ -554,15 +554,11 @@ impl MLPBlock { // Apply experts -> [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts. Keep the same op shape as mlx-lm's - // `x * expand_dims(expert_weights, -1); x.sum(axis=-2)`. Using - // `einsum("nkh,nk->nh")` promotes the contraction to f32 and slows - // GptOss decode on M5. - let scores_exp = mlxcel_core::expand_dims(&scores, -1); - let scores_exp = mlxcel_core::astype(&scores_exp, mlxcel_core::array_dtype(&expert_out)); - let weighted = mlxcel_core::multiply(&expert_out, &scores_exp); - let result = mlxcel_core::sum_axis(&weighted, -2, false); - let result = mlxcel_core::astype(&result, mlxcel_core::array_dtype(&x_flat)); + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back if orig_shape.len() > 2 { @@ -1287,3 +1283,28 @@ impl GptOssStageModel { .map(|idx| caches[idx].as_interface().offset()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn gpt_oss_swiglu_fallback_preserves_bf16_and_f16_dtype() { + for dtype in [mlxcel_core::dtype::BFLOAT16, mlxcel_core::dtype::FLOAT16] { + let x_linear = mlxcel_core::astype( + &mlxcel_core::from_slice_f32(&[-4.0, -1.0, 2.0, 4.0], &[1, 4]), + dtype, + ); + let x_glu = mlxcel_core::astype( + &mlxcel_core::from_slice_f32(&[-2.0, 0.5, 2.0, 5.0], &[1, 4]), + dtype, + ); + + let out = gpt_oss_swiglu(&x_linear, &x_glu, 6.0); + mlxcel_core::eval(&out); + + assert_eq!(mlxcel_core::array_shape(&out), vec![1, 4]); + assert_eq!(mlxcel_core::array_dtype(&out), dtype); + } + } +} diff --git a/src/models/hunyuan_moe.rs b/src/models/hunyuan_moe.rs index 6a5d577..8f3d90a 100644 --- a/src/models/hunyuan_moe.rs +++ b/src/models/hunyuan_moe.rs @@ -522,17 +522,11 @@ impl MoeBlock { // Apply experts let expert_out = self.switch_mlp.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; - - // Convert back to original dtype - let expert_dtype = mlxcel_core::array_dtype(&expert_out); - result = mlxcel_core::astype(&result, expert_dtype); + let mut result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared expert output if present if let Some(ref shared_mlp) = self.shared_mlp { diff --git a/src/models/kimi_linear.rs b/src/models/kimi_linear.rs index 8766409..4981ecf 100644 --- a/src/models/kimi_linear.rs +++ b/src/models/kimi_linear.rs @@ -842,13 +842,11 @@ impl KimiSparseMoE { // Expert computation let expert_out = self.switch_mlp.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - topk_scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut y = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut y = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &topk_scores, + mlxcel_core::array_dtype(&x_flat), + ); // Shared experts if let Some(ref shared) = self.shared_experts { diff --git a/src/models/minimax.rs b/src/models/minimax.rs index 16c878f..9cdbef0 100644 --- a/src/models/minimax.rs +++ b/src/models/minimax.rs @@ -282,12 +282,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum("nkh,nk->nh") - let operands: [*const MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - norm_scores.as_ref().unwrap() as *const _, - ]; - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &norm_scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { diff --git a/src/models/mistral4.rs b/src/models/mistral4.rs index 61ba422..ddc0426 100644 --- a/src/models/mistral4.rs +++ b/src/models/mistral4.rs @@ -478,13 +478,11 @@ impl Mistral4MoE { // Dispatch to selected experts let y = self.switch_mlp.forward(x, &inds); - // Weighted sum: einsum("nkh,nk->nh", y, scores) - let operands: [*const MlxArray; 2] = [ - y.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &y, + &scores, + mlxcel_core::array_dtype(x), + ); // Add shared expert output if let Some(ref shared) = self.shared_experts { diff --git a/src/models/mixtral.rs b/src/models/mixtral.rs index abb570d..574d277 100644 --- a/src/models/mixtral.rs +++ b/src/models/mixtral.rs @@ -312,13 +312,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { diff --git a/src/models/moondream3.rs b/src/models/moondream3.rs index 2d1db2b..22f49f8 100644 --- a/src/models/moondream3.rs +++ b/src/models/moondream3.rs @@ -390,11 +390,11 @@ impl SparseMoeMlp { let hidden = self.fc2.forward(&hidden, &topk, false); let hidden = mlxcel_core::squeeze_axis(&hidden, -2); - let operands: [*const MlxArray; 2] = [ - hidden.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - let combined = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let combined = crate::models::switch_layers::moe_weighted_sum( + &hidden, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); if orig_shape.len() > 2 { mlxcel_core::reshape(&combined, &orig_shape) diff --git a/src/models/olmoe.rs b/src/models/olmoe.rs index 48b1dfc..b06ddcb 100644 --- a/src/models/olmoe.rs +++ b/src/models/olmoe.rs @@ -336,13 +336,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { diff --git a/src/models/phimoe.rs b/src/models/phimoe.rs index 0fdba74..428c9c4 100644 --- a/src/models/phimoe.rs +++ b/src/models/phimoe.rs @@ -360,13 +360,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { diff --git a/src/models/qwen2_moe.rs b/src/models/qwen2_moe.rs index 895b8be..41badcb 100644 --- a/src/models/qwen2_moe.rs +++ b/src/models/qwen2_moe.rs @@ -436,13 +436,11 @@ impl SparseMoeBlock { // Apply experts let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let expert_out_sum = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let expert_out_sum = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Compute shared expert output let shared_output = self.shared_expert.forward(&x_flat); diff --git a/src/models/qwen3_5.rs b/src/models/qwen3_5.rs index df95355..9ecb20e 100644 --- a/src/models/qwen3_5.rs +++ b/src/models/qwen3_5.rs @@ -2348,7 +2348,7 @@ pub fn sanitize_weights(mut weights: WeightMap, config: &Qwen35Config) -> Weight let v = weights.get(k.as_str()).unwrap(); let ndim = mlxcel_core::array_shape(v).len(); if ndim == 1 { - let one = mlxcel_core::full_f32(&[1], 1.0, dtype::FLOAT32); + let one = mlxcel_core::full_f32(&[1], 1.0, mlxcel_core::array_dtype(v)); let shifted = mlxcel_core::add(v, &one); weights.insert(k.clone(), shifted); } diff --git a/src/models/qwen3_moe.rs b/src/models/qwen3_moe.rs index 8ba1494..fcc86ba 100644 --- a/src/models/qwen3_moe.rs +++ b/src/models/qwen3_moe.rs @@ -336,13 +336,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: [n_tokens, k, hidden] * [n_tokens, k] -> [n_tokens, hidden] - // einsum fuses expand_dims + multiply + sum_axis into single kernel - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { @@ -387,11 +385,11 @@ impl SparseMoeBlock { let expert_ms = expert_start.elapsed().as_secs_f64() * 1000.0; let combine_start = std::time::Instant::now(); - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); let result = if orig_shape.len() > 2 { mlxcel_core::reshape(&result, &orig_shape) } else { diff --git a/src/models/qwen3_next.rs b/src/models/qwen3_next.rs index 8abc0a3..15115a5 100644 --- a/src/models/qwen3_next.rs +++ b/src/models/qwen3_next.rs @@ -971,13 +971,11 @@ impl SparseMoeBlock { // Expert computation let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let y = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let y = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Shared expert let shared_y = self.shared_expert.forward(&x_flat); diff --git a/src/models/qwen3_vl_moe.rs b/src/models/qwen3_vl_moe.rs index e0254b1..c4a19c9 100644 --- a/src/models/qwen3_vl_moe.rs +++ b/src/models/qwen3_vl_moe.rs @@ -560,13 +560,11 @@ impl SparseMoeBlock { // Apply experts - returns [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &scores, + mlxcel_core::array_dtype(&x_flat), + ); // Reshape back to original shape if orig_shape.len() > 2 { diff --git a/src/models/solar_open.rs b/src/models/solar_open.rs index 2189f46..d13387f 100644 --- a/src/models/solar_open.rs +++ b/src/models/solar_open.rs @@ -641,12 +641,11 @@ impl MoE { // Apply experts via SwitchGLU let expert_out = self.switch_mlp.forward(&x_flat, &topk_indices); - // Weighted sum: einsum("nkh,nk->nh") - let operands: [*const MlxArray; 2] = [ - expert_out.as_ref().unwrap() as *const _, - topk_scores.as_ref().unwrap() as *const _, - ]; - let mut result = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; + let mut result = crate::models::switch_layers::moe_weighted_sum( + &expert_out, + &topk_scores, + mlxcel_core::array_dtype(&x_flat), + ); // Add shared expert if present if let Some(ref shared) = self.shared_expert { diff --git a/src/models/step3p5.rs b/src/models/step3p5.rs index 9ef4ab9..2dc3e60 100644 --- a/src/models/step3p5.rs +++ b/src/models/step3p5.rs @@ -717,17 +717,8 @@ impl Step3p5MoE { // Expert computation let y = self.switch_mlp.forward(x, &indices); - // Weighted sum over experts: einsum fuses expand_dims + multiply + sum_axis - let operands: [*const mlxcel_core::MlxArray; 2] = [ - y.as_ref().unwrap() as *const _, - scores.as_ref().unwrap() as *const _, - ]; - // SAFETY: operands are valid pointers to MlxArray owned by UniquePtr in this scope - let routed_output = unsafe { mlxcel_core::einsum("nkh,nk->nh", &operands) }; - - // Cast back to original dtype let x_dtype = mlxcel_core::array_dtype(x); - let routed_output = mlxcel_core::astype(&routed_output, x_dtype); + let routed_output = crate::models::switch_layers::moe_weighted_sum(&y, &scores, x_dtype); // Add shared expert let shared_output = self.share_expert.forward(x); diff --git a/src/models/switch_layers.rs b/src/models/switch_layers.rs index ab46c6b..80bf67b 100644 --- a/src/models/switch_layers.rs +++ b/src/models/switch_layers.rs @@ -268,6 +268,33 @@ fn scatter_unsort(x: &MlxArray, inv_order: &MlxArray, orig_shape: &[i32]) -> Uni mlxcel_core::squeeze_axis(&reshaped, 2) } +/// Weighted sum over selected expert outputs while preserving the residual dtype. +/// +/// Used by: DeepSeek, DeepSeekV3, DeepSeekV32, ExaOneMoe, Ernie4_5Moe, +/// GLM4Moe, GLM4MoeLite, GptOss, HunyuanMoe, KimiLinear, MiniMax, +/// Mistral4, Mixtral, Moondream3, OLMoE, PhiMoE, Qwen2Moe, Qwen3Moe, +/// Qwen3Next, Qwen3VLMoe, SolarOpen, Step3p5 +/// +/// The old `nkh,nk->nh` einsum contraction promotes the combine to float32 +/// on M5 for bf16/f16 activations. Match mlx-lm's `y * scores[..., None]` +/// followed by `sum(axis=-2)`, with scores cast to the expert output dtype and +/// the final result restored to the hidden/residual dtype. +pub fn moe_weighted_sum( + expert_out: &MlxArray, + scores: &MlxArray, + output_dtype: i32, +) -> UniquePtr { + let scores_exp = mlxcel_core::expand_dims(scores, -1); + let scores_exp = mlxcel_core::astype(&scores_exp, mlxcel_core::array_dtype(expert_out)); + let weighted = mlxcel_core::multiply(expert_out, &scores_exp); + let summed = mlxcel_core::sum_axis(&weighted, -2, false); + if mlxcel_core::array_dtype(&summed) == output_dtype { + summed + } else { + mlxcel_core::astype(&summed, output_dtype) + } +} + /// Group-based expert masking for MoE gates with n_group > 1. /// /// Selects the top `topk_group` expert groups (by sum of top-2 scores per group) @@ -304,3 +331,46 @@ pub fn group_mask_scores(scores: &MlxArray, n_group: i32, topk_group: i32) -> Un // Flatten back: [n, n_group, experts_per_group] -> [n, n_experts] mlxcel_core::reshape(&grouped, &[n, n_experts]) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn moe_weighted_sum_preserves_bf16_output_dtype() { + let expert_f32 = mlxcel_core::from_slice_f32( + &[ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // + 2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0, + ], + &[2, 2, 4], + ); + let expert = mlxcel_core::astype(&expert_f32, dtype::BFLOAT16); + let scores = mlxcel_core::from_slice_f32(&[0.25, 0.75, 0.5, 0.5], &[2, 2]); + + let out = moe_weighted_sum(&expert, &scores, dtype::BFLOAT16); + mlxcel_core::eval(&out); + + assert_eq!(mlxcel_core::array_shape(&out), vec![2, 4]); + assert_eq!(mlxcel_core::array_dtype(&out), dtype::BFLOAT16); + } + + #[test] + fn moe_weighted_sum_preserves_f16_output_dtype() { + let expert_f32 = mlxcel_core::from_slice_f32( + &[ + 1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, // + 0.0, 2.0, 0.0, 4.0, 0.0, 6.0, 0.0, 8.0, + ], + &[2, 2, 4], + ); + let expert = mlxcel_core::astype(&expert_f32, dtype::FLOAT16); + let scores = mlxcel_core::from_slice_f32(&[0.5, 0.5, 0.25, 0.75], &[2, 2]); + + let out = moe_weighted_sum(&expert, &scores, dtype::FLOAT16); + mlxcel_core::eval(&out); + + assert_eq!(mlxcel_core::array_shape(&out), vec![2, 4]); + assert_eq!(mlxcel_core::array_dtype(&out), dtype::FLOAT16); + } +}