Skip to content

Commit dd49781

Browse files
author
Aegis-AI
committed
Merge feat/mtp-speculative-decoding to main
2 parents 3362bab + b9bf50b commit dd49781

2 files changed

Lines changed: 25 additions & 7 deletions

File tree

Libraries/MLXLLM/Models/Gemma4Text.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,11 +1117,8 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens
11171117
let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32)
11181118
let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates])
11191119
var output2D = output.reshaped([B * S, vocabSize])
1120-
for bsIdx in 0 ..< B * S {
1121-
let idxRow = scatterIdx2D[bsIdx] // [totalCandidates]
1122-
let valRow = selectedLogits2D[bsIdx] // [totalCandidates]
1123-
output2D[bsIdx, idxRow] = valRow
1124-
}
1120+
let rowIndices = MLXArray(0 ..< Int32(B * S)).reshaped([B * S, 1])
1121+
output2D[rowIndices, scatterIdx2D] = selectedLogits2D
11251122
output = output2D.reshaped([B, S, vocabSize])
11261123

11271124
return output

Libraries/MLXLMCommon/Evaluate.swift

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,8 @@ protocol TokenIteratorProtocol: Sequence, IteratorProtocol where Element == Int
503503
var tokenCount: Int { get }
504504
var promptPrefillTime: TimeInterval { get }
505505
var streamingError: SSDStreamingError? { get }
506+
var acceptedDraftTokens: Int { get }
507+
var totalDraftTokens: Int { get }
506508
}
507509

508510
/// Generator of tokens.
@@ -549,6 +551,8 @@ public struct TokenIterator: TokenIteratorProtocol {
549551
var promptPrefillTime: TimeInterval = 0.0
550552
var streamingError: SSDStreamingError?
551553
let ssdErrorLatch = SSDStreamingErrorLatch()
554+
var acceptedDraftTokens = 0
555+
var totalDraftTokens = 0
552556

553557
/// Initialize a `TokenIterator` with the given tokens. Note: this has been
554558
/// replaced with ``init(input:model:cache:parameters:)``.
@@ -794,6 +798,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
794798

795799
// Internal metrics
796800
var promptPrefillTime: TimeInterval = 0.0
801+
var acceptedDraftTokens = 0
802+
var totalDraftTokens = 0
797803

798804
/// Initialize a `SpeculativeTokenIterator` with the given input.
799805
///
@@ -1014,6 +1020,9 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol {
10141020
trimPromptCache(mainCache, numTokens: numDraft - accepted)
10151021
trimPromptCache(draftCache, numTokens: Swift.max(numDraft - accepted - 1, 0))
10161022

1023+
self.acceptedDraftTokens += accepted
1024+
self.totalDraftTokens += draftTokens.count
1025+
10171026
// Apply dynamic cache quantization after rewind
10181027
quantizeKVCache(&mainCache)
10191028
quantizeKVCache(&draftCache)
@@ -2228,7 +2237,9 @@ private func generateLoopTask<Handler: TokenLoopHandler>(
22282237
generationTokenCount: tokenCount,
22292238
promptTime: promptTime + iterator.promptPrefillTime,
22302239
generationTime: generateTime,
2231-
stopReason: stopReason ?? .cancelled
2240+
stopReason: stopReason ?? .cancelled,
2241+
acceptedDraftTokens: iterator.acceptedDraftTokens,
2242+
totalDraftTokens: iterator.totalDraftTokens
22322243
)
22332244
_ = continuation.yield(handler.infoEvent(info))
22342245

@@ -2298,6 +2309,12 @@ public struct GenerateCompletionInfo: Sendable {
22982309
/// Reason generation stopped.
22992310
public let stopReason: GenerateStopReason
23002311

2312+
/// Number of accepted draft tokens (if speculative decoding is active).
2313+
public let acceptedDraftTokens: Int
2314+
2315+
/// Total number of draft tokens evaluated (if speculative decoding is active).
2316+
public let totalDraftTokens: Int
2317+
23012318
/// The number of tokens processed per second during the prompt phase.
23022319
public var promptTokensPerSecond: Double {
23032320
Double(promptTokenCount) / promptTime
@@ -2313,13 +2330,17 @@ public struct GenerateCompletionInfo: Sendable {
23132330
generationTokenCount: Int,
23142331
promptTime: TimeInterval,
23152332
generationTime: TimeInterval,
2316-
stopReason: GenerateStopReason = .stop
2333+
stopReason: GenerateStopReason = .stop,
2334+
acceptedDraftTokens: Int = 0,
2335+
totalDraftTokens: Int = 0
23172336
) {
23182337
self.promptTokenCount = promptTokenCount
23192338
self.generationTokenCount = generationTokenCount
23202339
self.promptTime = promptTime
23212340
self.generateTime = generationTime
23222341
self.stopReason = stopReason
2342+
self.acceptedDraftTokens = acceptedDraftTokens
2343+
self.totalDraftTokens = totalDraftTokens
23232344
}
23242345

23252346
public func summary() -> String {

0 commit comments

Comments
 (0)