diff --git a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp index 85f518a0..c973a3c3 100644 --- a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp +++ b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp @@ -1160,6 +1160,42 @@ std::unique_ptr compiled_swiglu_activation( return std::make_unique(std::move(result[0])); } +// Compiled GptOss SwiGLU activation using the exact mlx-lm formulation: +// x_glu = clip(x_glu, max=7) +// x_linear = clip(x_linear, min=-7, max=7) +// return x_glu * sigmoid(1.702 * x_glu) * (x_linear + 1) +// Used by: GptOss +namespace { + static std::function(const std::vector&)> + get_compiled_gpt_oss_swiglu_activation() { + auto fn = [](const std::vector& inputs) -> std::vector { + const auto& x_linear_in = inputs[0]; + const auto& x_glu_in = inputs[1]; + + auto pos_limit = mlx::core::array(7.0f); + auto neg_limit = mlx::core::array(-7.0f); + auto x_glu = mlx::core::minimum(x_glu_in, pos_limit); + auto x_linear = mlx::core::maximum(x_linear_in, neg_limit); + x_linear = mlx::core::minimum(x_linear, pos_limit); + + auto glu_scaled = mlx::core::multiply(mlx::core::array(1.702f), x_glu); + auto out_glu = mlx::core::multiply(x_glu, mlx::core::sigmoid(glu_scaled)); + auto result = mlx::core::multiply(out_glu, mlx::core::add(x_linear, mlx::core::array(1.0f))); + return {mlx::core::astype(result, x_linear_in.dtype())}; + }; + return mlx::core::compile(fn, true); + } +} + +std::unique_ptr compiled_gpt_oss_swiglu_activation( + const MlxArray& x_linear, + const MlxArray& x_glu +) { + static auto compiled_fn = get_compiled_gpt_oss_swiglu_activation(); + auto result = compiled_fn({x_linear.inner, x_glu.inner}); + return std::make_unique(std::move(result[0])); +} + // Compiled GELU: x * 0.5 * (1 + erf(x / sqrt(2))) // Used by: Gemma2, Gemma3, StarCoder2 namespace { diff --git a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h index e5cc9145..cc573e58 100644 --- a/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h +++ b/src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h @@ -398,6 +398,14 @@ std::unique_ptr compiled_swiglu_activation( const MlxArray& x ); +// GptOss SwiGLU activation only - compiled with kernel fusion (shapeless=true) +// output = clipped_gate * sigmoid(1.702 * clipped_gate) * (clipped_up + 1) +// Used by: GptOss +std::unique_ptr compiled_gpt_oss_swiglu_activation( + const MlxArray& x_linear, + const MlxArray& x_glu +); + // GeGLU activation - compiled with kernel fusion (shapeless=true) // output = gelu(gate) * x // Used by: Gemma, Gemma2, Gemma3 MLP layers diff --git a/src/lib/mlxcel-core/src/ffi_tests.rs b/src/lib/mlxcel-core/src/ffi_tests.rs index 47472bf9..de558ee3 100644 --- a/src/lib/mlxcel-core/src/ffi_tests.rs +++ b/src/lib/mlxcel-core/src/ffi_tests.rs @@ -466,6 +466,24 @@ fn test_compiled_swiglu_activation() { assert!(item_f32(&total) > 0.0); } +#[test] +fn test_compiled_gpt_oss_swiglu_activation_preserves_input_dtype() { + let x_linear = astype( + &from_slice_f32(&[-8.0, -1.0, 2.0, 8.0], &[1, 4]), + dtype::BFLOAT16, + ); + let x_glu = astype( + &from_slice_f32(&[-2.0, 0.5, 2.0, 8.0], &[1, 4]), + dtype::BFLOAT16, + ); + + let out = compiled_gpt_oss_swiglu_activation(&x_linear, &x_glu); + eval(&out); + + assert_eq!(array_shape(&out), vec![1, 4]); + assert_eq!(array_dtype(&out), dtype::BFLOAT16); +} + #[test] fn test_new_ops() { let x = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]); diff --git a/src/lib/mlxcel-core/src/lib.rs b/src/lib/mlxcel-core/src/lib.rs index d6589551..f7e9107f 100644 --- a/src/lib/mlxcel-core/src/lib.rs +++ b/src/lib/mlxcel-core/src/lib.rs @@ -525,6 +525,14 @@ mod ffi { /// output = silu(gate) * x fn compiled_swiglu_activation(gate: &MlxArray, x: &MlxArray) -> UniquePtr; + /// Compiled GptOss SwiGLU activation with kernel fusion + /// Matches mlx-lm gpt_oss.swiglu: clipped gate/up + sigmoid(1.702*gate). + /// Used by: GptOss + fn compiled_gpt_oss_swiglu_activation( + x_linear: &MlxArray, + x_glu: &MlxArray, + ) -> UniquePtr; + /// Compiled relu_squared: square(maximum(x, 0)) — single fused kernel fn compiled_relu_squared(x: &MlxArray) -> UniquePtr; diff --git a/src/models/gpt_oss.rs b/src/models/gpt_oss.rs index 65341948..a6c76b1a 100644 --- a/src/models/gpt_oss.rs +++ b/src/models/gpt_oss.rs @@ -401,6 +401,11 @@ impl ExpertLinear { // out_glu = x_glu * sigmoid(alpha * x_glu) // return out_glu * (x_linear + 1) fn gpt_oss_swiglu(x_linear: &MlxArray, x_glu: &MlxArray, limit: f32) -> UniquePtr { + if (limit - 7.0).abs() <= f32::EPSILON { + return mlxcel_core::compiled_gpt_oss_swiglu_activation(x_linear, x_glu); + } + + let input_dtype = mlxcel_core::array_dtype(x_linear); let alpha = 1.702f32; // Clamp values @@ -419,7 +424,8 @@ fn gpt_oss_swiglu(x_linear: &MlxArray, x_glu: &MlxArray, limit: f32) -> UniquePt // (x_linear + 1) * out_glu let one = mlxcel_core::from_slice_f32(&[1.0], &[1]); let x_linear_plus_1 = mlxcel_core::add(&x_linear, &one); - mlxcel_core::multiply(&out_glu, &x_linear_plus_1) + let result = mlxcel_core::multiply(&out_glu, &x_linear_plus_1); + mlxcel_core::astype(&result, input_dtype) } // SwitchGLU for GptOss (custom activation, MXFP4 support) @@ -437,22 +443,25 @@ impl GptOssSwitchGLU { let top_k = indices_shape[1]; let total = n_tokens * top_k; let do_sort = total >= 64; + let hidden_size = mlxcel_core::array_shape(x)[1]; - let x_exp = mlxcel_core::expand_dims(x, -2); - let x_exp = mlxcel_core::expand_dims(&x_exp, -3); + // Python writes this as `mx.expand_dims(x, (-2, -3))`, producing + // [tokens, 1, 1, hidden]. The input here is already flattened to rank + // 2, so a reshape is equivalent and avoids two decode-hot shape ops. + let x_exp = mlxcel_core::reshape(x, &[n_tokens, 1, 1, hidden_size]); if do_sort { let (sorted_x, sorted_idx, inv_order) = crate::models::switch_layers::gather_sort(&x_exp, indices); - let x_gate = self.gate_proj.forward(&sorted_x, &sorted_idx, true); let x_up = self.up_proj.forward(&sorted_x, &sorted_idx, true); + let x_gate = self.gate_proj.forward(&sorted_x, &sorted_idx, true); // Python SwitchGLU: activation(x_up, x_gate) → swiglu(x_linear=x_up, x_glu=x_gate) let activated = gpt_oss_swiglu(&x_up, &x_gate, self.swiglu_limit); let output = self.down_proj.forward(&activated, &sorted_idx, true); scatter_unsort(&output, &inv_order, &indices_shape) } else { - let x_gate = self.gate_proj.forward(&x_exp, indices, false); let x_up = self.up_proj.forward(&x_exp, indices, false); + let x_gate = self.gate_proj.forward(&x_exp, indices, false); // Python SwitchGLU: activation(x_up, x_gate) → swiglu(x_linear=x_up, x_glu=x_gate) let activated = gpt_oss_swiglu(&x_up, &x_gate, self.swiglu_limit); let output = self.down_proj.forward(&activated, indices, false); @@ -540,16 +549,20 @@ impl MLPBlock { // Softmax over top-k logits let topk_logits = mlxcel_core::take_along_axis(&logits, &topk_indices, -1); - let scores = mlxcel_core::softmax(&topk_logits, -1); + let scores = mlxcel_core::softmax_precise(&topk_logits, -1); // Apply experts -> [n_tokens, k, hidden] let expert_out = self.experts.forward(&x_flat, &topk_indices); - // Weighted sum over experts: einsum("nkh,nk->nh") - // Multiply expert_out by expanded scores and sum over k + // Weighted sum over experts. Keep the same op shape as mlx-lm's + // `x * expand_dims(expert_weights, -1); x.sum(axis=-2)`. Using + // `einsum("nkh,nk->nh")` promotes the contraction to f32 and slows + // GptOss decode on M5. let scores_exp = mlxcel_core::expand_dims(&scores, -1); + let scores_exp = mlxcel_core::astype(&scores_exp, mlxcel_core::array_dtype(&expert_out)); let weighted = mlxcel_core::multiply(&expert_out, &scores_exp); let result = mlxcel_core::sum_axis(&weighted, -2, false); + let result = mlxcel_core::astype(&result, mlxcel_core::array_dtype(&x_flat)); // Reshape back if orig_shape.len() > 2 {