Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/models/solar_open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,14 +807,18 @@ impl SolarOpenModel {
caches: &mut [KVCache],
mask: Option<&MlxArray>,
) -> UniquePtr<MlxArray> {
let eval_layer_outputs = should_eval_layer_outputs(input_ids);
let mut h = self.embed_tokens.forward(input_ids);

for (i, layer) in self.layers.iter().enumerate() {
h = layer.forward(&h, &mut caches[i], mask);
// Eval every layer to prevent the MoE computation graph from
// growing too large (each layer has ~50 ops with 128 experts)
let ptrs = [h.as_ref().unwrap() as *const MlxArray];
unsafe { mlxcel_core::eval_all(&ptrs) };
// Keep the graph bounded for multi-token prefill. For single-token
// decode, final-logits evaluation flushes the graph once; forcing
// one sync per layer costs 48 GPU synchronizations per token.
if eval_layer_outputs {
let ptrs = [h.as_ref().unwrap() as *const MlxArray];
unsafe { mlxcel_core::eval_all(&ptrs) };
}
}

let h = self.norm.forward(&h);
Expand Down Expand Up @@ -910,6 +914,10 @@ impl LanguageModel for SolarOpenModel {
// Helper Functions
// ============================================================================

/// Reports whether the forward pass should force an evaluation after each layer.
///
/// Returns `false` only when the last dimension of `input_ids`' shape is
/// exactly `1`; every other shape (including one with no dimensions) yields
/// `true`.
///
/// NOTE(review): the last axis is presumably the sequence length, so a
/// trailing `1` would correspond to a single-token decode step and anything
/// else to multi-token prefill — confirm against the callers.
fn should_eval_layer_outputs(input_ids: &MlxArray) -> bool {
    // Copy the trailing dimension out so the shape container can be dropped.
    let trailing_dim = mlxcel_core::array_shape(input_ids).last().copied();
    let is_single_step = trailing_dim == Some(1);
    !is_single_step
}

fn get_weight_copy(weights: &WeightMap, name: &str) -> Result<UniquePtr<MlxArray>, String> {
weights
.get(name)
Expand Down Expand Up @@ -965,4 +973,13 @@ mod tests {
assert!(args.is_moe_layer(0)); // All layers are MoE
assert_eq!(args.rope_dims(), 128); // Full RoPE
}

/// Checks the per-layer eval decision for two input shapes: a `[1, 1]` input
/// must not request per-layer evaluation, while a `[1, 4]` input must.
#[test]
fn solar_open_skips_per_layer_eval_for_decode_step() {
    // Trailing dimension of 1: per-layer eval is expected to be skipped.
    let single_token = mlxcel_core::from_slice_i32(&[1], &[1, 1]);
    // Trailing dimension of 4: per-layer eval is expected to stay enabled.
    let multi_token = mlxcel_core::from_slice_i32(&[1, 2, 3, 4], &[1, 4]);

    let skip_expected = should_eval_layer_outputs(&single_token);
    let eval_expected = should_eval_layer_outputs(&multi_token);

    assert!(!skip_expected);
    assert!(eval_expected);
}
}