Skip to content

Commit e3fb72a

Browse files
author
Aegis-AI
committed
Fix GPU Hang: flush Metal graph with eval() after MTPTokenIterator verification pass
On hybrid SSM/attention models (Qwen35), the recurrent GatedDeltaNet layers accumulate un-evaluated MLX graph nodes across each speculateRound(). Without an explicit eval() after callMTP(), the Metal command buffer grows across multiple speculation rounds until it triggers the GPU Watchdog (kIOGPUCommandBufferCallbackErrorHang). Adding eval(mtpResult) immediately after the verification forward pass flushes the accumulated graph, preventing the Metal timeout.
1 parent b42d9a0 commit e3fb72a

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,13 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
12271227

12281228
let mtpResult = model.callMTP(verifyInput.tokens[.newAxis], cache: cache, mtpCaches: mtpCaches)
12291229
guard !mtpResult.isEmpty else { return }
1230-
1230+
1231+
// Flush the Metal command buffer immediately after the verification forward pass.
1232+
// On hybrid SSM/attention models (e.g. Qwen35), the recurrent SSM layers accumulate
1233+
// un-evaluated graph nodes across rounds. Without an explicit sync here the Metal
1234+
// command buffer grows until it triggers the GPU Watchdog.
1235+
eval(mtpResult)
1236+
12311237
let mainLogits = mtpResult[0]
12321238

12331239
let mainTokens: MLXArray

0 commit comments

Comments
 (0)