Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,42 @@ std::unique_ptr<MlxArray> compiled_swiglu_activation(
return std::make_unique<MlxArray>(std::move(result[0]));
}

// Compiled GptOss SwiGLU activation using the exact mlx-lm formulation:
// x_glu = clip(x_glu, max=7)
// x_linear = clip(x_linear, min=-7, max=7)
// return x_glu * sigmoid(1.702 * x_glu) * (x_linear + 1)
// Used by: GptOss
namespace {
static std::function<std::vector<array>(const std::vector<array>&)>
get_compiled_gpt_oss_swiglu_activation() {
auto fn = [](const std::vector<array>& inputs) -> std::vector<array> {
const auto& x_linear_in = inputs[0];
const auto& x_glu_in = inputs[1];

auto pos_limit = mlx::core::array(7.0f);
auto neg_limit = mlx::core::array(-7.0f);
auto x_glu = mlx::core::minimum(x_glu_in, pos_limit);
auto x_linear = mlx::core::maximum(x_linear_in, neg_limit);
x_linear = mlx::core::minimum(x_linear, pos_limit);

auto glu_scaled = mlx::core::multiply(mlx::core::array(1.702f), x_glu);
auto out_glu = mlx::core::multiply(x_glu, mlx::core::sigmoid(glu_scaled));
auto result = mlx::core::multiply(out_glu, mlx::core::add(x_linear, mlx::core::array(1.0f)));
return {mlx::core::astype(result, x_linear_in.dtype())};
};
return mlx::core::compile(fn, true);
}
}

std::unique_ptr<MlxArray> compiled_gpt_oss_swiglu_activation(
const MlxArray& x_linear,
const MlxArray& x_glu
) {
static auto compiled_fn = get_compiled_gpt_oss_swiglu_activation();
auto result = compiled_fn({x_linear.inner, x_glu.inner});
return std::make_unique<MlxArray>(std::move(result[0]));
}

// Compiled GELU: x * 0.5 * (1 + erf(x / sqrt(2)))
// Used by: Gemma2, Gemma3, StarCoder2
namespace {
Expand Down
8 changes: 8 additions & 0 deletions src/lib/mlxcel-core/cpp/mlx_cxx_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,14 @@ std::unique_ptr<MlxArray> compiled_swiglu_activation(
const MlxArray& x
);

// GptOss SwiGLU activation only - compiled with kernel fusion (shapeless=true)
// output = clipped_gate * sigmoid(1.702 * clipped_gate) * (clipped_up + 1)
// Used by: GptOss
std::unique_ptr<MlxArray> compiled_gpt_oss_swiglu_activation(
const MlxArray& x_linear,
const MlxArray& x_glu
);

// GeGLU activation - compiled with kernel fusion (shapeless=true)
// output = gelu(gate) * x
// Used by: Gemma, Gemma2, Gemma3 MLP layers
Expand Down
18 changes: 18 additions & 0 deletions src/lib/mlxcel-core/src/ffi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,24 @@ fn test_compiled_swiglu_activation() {
assert!(item_f32(&total) > 0.0);
}

#[test]
fn test_compiled_gpt_oss_swiglu_activation_preserves_input_dtype() {
let x_linear = astype(
&from_slice_f32(&[-8.0, -1.0, 2.0, 8.0], &[1, 4]),
dtype::BFLOAT16,
);
let x_glu = astype(
&from_slice_f32(&[-2.0, 0.5, 2.0, 8.0], &[1, 4]),
dtype::BFLOAT16,
);

let out = compiled_gpt_oss_swiglu_activation(&x_linear, &x_glu);
eval(&out);

assert_eq!(array_shape(&out), vec![1, 4]);
assert_eq!(array_dtype(&out), dtype::BFLOAT16);
}

#[test]
fn test_new_ops() {
let x = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
Expand Down
8 changes: 8 additions & 0 deletions src/lib/mlxcel-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,14 @@ mod ffi {
/// output = silu(gate) * x
fn compiled_swiglu_activation(gate: &MlxArray, x: &MlxArray) -> UniquePtr<MlxArray>;

/// Compiled GptOss SwiGLU activation with kernel fusion
/// Matches mlx-lm gpt_oss.swiglu: clipped gate/up + sigmoid(1.702*gate).
/// Used by: GptOss
fn compiled_gpt_oss_swiglu_activation(
x_linear: &MlxArray,
x_glu: &MlxArray,
) -> UniquePtr<MlxArray>;

/// Compiled relu_squared: square(maximum(x, 0)) — single fused kernel
fn compiled_relu_squared(x: &MlxArray) -> UniquePtr<MlxArray>;

Expand Down
29 changes: 21 additions & 8 deletions src/models/gpt_oss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,11 @@ impl ExpertLinear {
// out_glu = x_glu * sigmoid(alpha * x_glu)
// return out_glu * (x_linear + 1)
fn gpt_oss_swiglu(x_linear: &MlxArray, x_glu: &MlxArray, limit: f32) -> UniquePtr<MlxArray> {
if (limit - 7.0).abs() <= f32::EPSILON {
return mlxcel_core::compiled_gpt_oss_swiglu_activation(x_linear, x_glu);
}

let input_dtype = mlxcel_core::array_dtype(x_linear);
let alpha = 1.702f32;

// Clamp values
Expand All @@ -419,7 +424,8 @@ fn gpt_oss_swiglu(x_linear: &MlxArray, x_glu: &MlxArray, limit: f32) -> UniquePt
// (x_linear + 1) * out_glu
let one = mlxcel_core::from_slice_f32(&[1.0], &[1]);
let x_linear_plus_1 = mlxcel_core::add(&x_linear, &one);
mlxcel_core::multiply(&out_glu, &x_linear_plus_1)
let result = mlxcel_core::multiply(&out_glu, &x_linear_plus_1);
mlxcel_core::astype(&result, input_dtype)
}

// SwitchGLU for GptOss (custom activation, MXFP4 support)
Expand All @@ -437,22 +443,25 @@ impl GptOssSwitchGLU {
let top_k = indices_shape[1];
let total = n_tokens * top_k;
let do_sort = total >= 64;
let hidden_size = mlxcel_core::array_shape(x)[1];

let x_exp = mlxcel_core::expand_dims(x, -2);
let x_exp = mlxcel_core::expand_dims(&x_exp, -3);
// Python writes this as `mx.expand_dims(x, (-2, -3))`, producing
// [tokens, 1, 1, hidden]. The input here is already flattened to rank
// 2, so a reshape is equivalent and avoids two decode-hot shape ops.
let x_exp = mlxcel_core::reshape(x, &[n_tokens, 1, 1, hidden_size]);

if do_sort {
let (sorted_x, sorted_idx, inv_order) =
crate::models::switch_layers::gather_sort(&x_exp, indices);
let x_gate = self.gate_proj.forward(&sorted_x, &sorted_idx, true);
let x_up = self.up_proj.forward(&sorted_x, &sorted_idx, true);
let x_gate = self.gate_proj.forward(&sorted_x, &sorted_idx, true);
// Python SwitchGLU: activation(x_up, x_gate) → swiglu(x_linear=x_up, x_glu=x_gate)
let activated = gpt_oss_swiglu(&x_up, &x_gate, self.swiglu_limit);
let output = self.down_proj.forward(&activated, &sorted_idx, true);
scatter_unsort(&output, &inv_order, &indices_shape)
} else {
let x_gate = self.gate_proj.forward(&x_exp, indices, false);
let x_up = self.up_proj.forward(&x_exp, indices, false);
let x_gate = self.gate_proj.forward(&x_exp, indices, false);
// Python SwitchGLU: activation(x_up, x_gate) → swiglu(x_linear=x_up, x_glu=x_gate)
let activated = gpt_oss_swiglu(&x_up, &x_gate, self.swiglu_limit);
let output = self.down_proj.forward(&activated, indices, false);
Expand Down Expand Up @@ -540,16 +549,20 @@ impl MLPBlock {

// Softmax over top-k logits
let topk_logits = mlxcel_core::take_along_axis(&logits, &topk_indices, -1);
let scores = mlxcel_core::softmax(&topk_logits, -1);
let scores = mlxcel_core::softmax_precise(&topk_logits, -1);

// Apply experts -> [n_tokens, k, hidden]
let expert_out = self.experts.forward(&x_flat, &topk_indices);

// Weighted sum over experts: einsum("nkh,nk->nh")
// Multiply expert_out by expanded scores and sum over k
// Weighted sum over experts. Keep the same op shape as mlx-lm's
// `x * expand_dims(expert_weights, -1); x.sum(axis=-2)`. Using
// `einsum("nkh,nk->nh")` promotes the contraction to f32 and slows
// GptOss decode on M5.
let scores_exp = mlxcel_core::expand_dims(&scores, -1);
let scores_exp = mlxcel_core::astype(&scores_exp, mlxcel_core::array_dtype(&expert_out));
let weighted = mlxcel_core::multiply(&expert_out, &scores_exp);
let result = mlxcel_core::sum_axis(&weighted, -2, false);
let result = mlxcel_core::astype(&result, mlxcel_core::array_dtype(&x_flat));

// Reshape back
if orig_shape.len() > 2 {
Expand Down
Loading