diff --git a/src/models/solar_open.rs b/src/models/solar_open.rs index d13387f..b6624cc 100644 --- a/src/models/solar_open.rs +++ b/src/models/solar_open.rs @@ -807,14 +807,18 @@ impl SolarOpenModel { caches: &mut [KVCache], mask: Option<&MlxArray>, ) -> UniquePtr { + let eval_layer_outputs = should_eval_layer_outputs(input_ids); let mut h = self.embed_tokens.forward(input_ids); for (i, layer) in self.layers.iter().enumerate() { h = layer.forward(&h, &mut caches[i], mask); - // Eval every layer to prevent the MoE computation graph from - // growing too large (each layer has ~50 ops with 128 experts) - let ptrs = [h.as_ref().unwrap() as *const MlxArray]; - unsafe { mlxcel_core::eval_all(&ptrs) }; + // Keep the graph bounded for multi-token prefill. For single-token + // decode, final-logits evaluation flushes the graph once; forcing + // one sync per layer costs 48 GPU synchronizations per token. + if eval_layer_outputs { + let ptrs = [h.as_ref().unwrap() as *const MlxArray]; + unsafe { mlxcel_core::eval_all(&ptrs) }; + } } let h = self.norm.forward(&h); @@ -910,6 +914,10 @@ impl LanguageModel for SolarOpenModel { // Helper Functions // ============================================================================ +fn should_eval_layer_outputs(input_ids: &MlxArray) -> bool { + mlxcel_core::array_shape(input_ids).last().copied() != Some(1) +} + fn get_weight_copy(weights: &WeightMap, name: &str) -> Result, String> { weights .get(name) @@ -965,4 +973,13 @@ mod tests { assert!(args.is_moe_layer(0)); // All layers are MoE assert_eq!(args.rope_dims(), 128); // Full RoPE } + + #[test] + fn solar_open_skips_per_layer_eval_for_decode_step() { + let decode_ids = mlxcel_core::from_slice_i32(&[1], &[1, 1]); + let prefill_ids = mlxcel_core::from_slice_i32(&[1, 2, 3, 4], &[1, 4]); + + assert!(!should_eval_layer_outputs(&decode_ids)); + assert!(should_eval_layer_outputs(&prefill_ids)); + } }