Skip to content

Commit 3db113b

Browse files
committed
feat(nemotron_asr): incremental NemotronASRStreamSession for live mic
generateStream computes the whole-utterance mel up front, so a live caller only gets text after the entire buffer is in. NemotronASRStreamSession ingests audio as it arrives (step([Float]) -> Delta / finish()) and emits text per frozen chunk with the model's native delay. Bit-identical to generateStream(wholeAudio): normalize==NA makes mel frames independent; the centered STFT freezes frame m once m*hop+nFft/2<=bufferLen, so the session only feeds the encoder frozen whole chunks (tail flushed in finish()). Encoder + greedy RNN-T state moved into NemotronASRStreamEncoderState / NemotronASRStreamRNNTState; the chunk loop and the RNN-T decode are now shared by generateStream and the session (SSOT, no duplicated loops). CLI --stream drives the session for Nemotron (mirrors Voxtral). Tests assert session(chunked) == generateStream(whole) and feed-granularity invariance on a tiny synthetic model. Verified on mlx-community/nemotron-3.5-asr-streaming-0.6b-8bit (FLEURS-ru): 31/31 session==generateStream@{80ms,480ms feed}; TTFT 2.49s -> 0.09s on a 47s clip.
1 parent 3f6b055 commit 3db113b

5 files changed

Lines changed: 439 additions & 73 deletions

File tree

Sources/MLXAudioSTT/Models/NemotronASR/NemotronASRModel.swift

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -141,61 +141,18 @@ public final class NemotronASRModel: Module, STTGenerationModel {
141141
let frameSeconds = Double(self.encoderConfig.subsamplingFactor * self.preprocessConfig.hopLength)
142142
/ Double(sampleRate)
143143

144-
var results: [NemoAlignedToken] = []
145-
var lastToken = self.blankTokenID
146-
var decoderState: NemoLSTMState?
144+
let rnntState = NemotronASRStreamRNNTState(blankToken: self.blankTokenID)
147145
var previousText = ""
148-
var globalTime = 0
149146

150147
// Cache-aware streaming: incremental subsampling + per-layer attn/conv
151-
// caches. Token-identical to decode() at the native chunk size.
148+
// caches, greedy RNN-T per chunk. Token-identical to decode() at the
149+
// native chunk size; shares both loops with NemotronASRStreamSession.
152150
self.cacheAwareStreamEncode(mel, language: generationParameters.language) { prompted in
153-
let chunkLen = prompted.shape[1]
154-
var time = 0
155-
var newSymbols = 0
156-
while time < chunkLen {
157-
let frame = prompted[0..., time..<(time + 1), 0...]
158-
let currentToken: MLXArray? = lastToken == self.blankTokenID
159-
? nil
160-
: MLXArray(Int32(lastToken)).reshaped([1, 1]).asType(.int32)
161-
let decoderOutput = self.decoder(currentToken, state: decoderState)
162-
let pred = decoderOutput.0.asType(frame.dtype)
163-
let proposedState: NemoLSTMState = (
164-
hidden: decoderOutput.1.hidden?.asType(frame.dtype),
165-
cell: decoderOutput.1.cell?.asType(frame.dtype)
166-
)
167-
let jointOutput = self.joint(frame, pred)
168-
let token = jointOutput.argMax(axis: -1).item(Int.self)
169-
let step = NemoDecodingLogic.rnntStep(
170-
predictedToken: token,
171-
blankToken: self.blankTokenID,
172-
time: time,
173-
newSymbols: newSymbols,
174-
maxSymbols: self.maxSymbols
175-
)
176-
if step.emittedToken {
177-
lastToken = token
178-
decoderState = proposedState
179-
if !NemotronASRTokenizer.isSpecialToken(token, vocabulary: self.vocabulary) {
180-
results.append(
181-
NemoAlignedToken(
182-
id: token,
183-
text: NemotronASRTokenizer.decode(tokens: [token], vocabulary: self.vocabulary),
184-
start: Double(globalTime + time) * frameSeconds,
185-
duration: frameSeconds
186-
)
187-
)
188-
}
189-
}
190-
time = step.nextTime
191-
newSymbols = step.nextNewSymbols
192-
}
193-
globalTime += chunkLen
151+
self.streamRNNTDecode(prompted, state: rnntState, frameSeconds: frameSeconds)
194152

195-
let currentResult = NemoAlignment.sentencesToResult(
196-
NemoAlignment.tokensToSentences(results)
197-
)
198-
let fullText = currentResult.text
153+
let fullText = NemoAlignment.sentencesToResult(
154+
NemoAlignment.tokensToSentences(rnntState.results)
155+
).text
199156
let nextText = fullText.hasPrefix(previousText)
200157
? String(fullText.dropFirst(previousText.count))
201158
: fullText
@@ -206,7 +163,7 @@ public final class NemotronASRModel: Module, STTGenerationModel {
206163
}
207164

208165
let finalResult = NemoAlignment.sentencesToResult(
209-
NemoAlignment.tokensToSentences(results)
166+
NemoAlignment.tokensToSentences(rnntState.results)
210167
)
211168
continuation.yield(
212169
.result(
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import Foundation
2+
import MLX
3+
import MLXAudioCore
4+
import MLXNN
5+
6+
// True incremental (online) streaming for Nemotron 3.5 ASR.
7+
//
8+
// The offline `generateStream(...)` computes the mel of the whole utterance up
9+
// front and only then walks the cache-aware encoder. A live caller (e.g. a mic
10+
// feeding 80 ms chunks) instead wants to push audio as it arrives and read text
11+
// back with the model's native chunk delay. This session does exactly that, while
12+
// staying bit-identical to the offline encode.
13+
//
14+
// Why it is bit-identical (transcript == generateStream(wholeAudio)):
15+
// * The preprocessor uses `normalize: "NA"` (verified on the shipped checkpoint),
16+
// so each mel frame is an independent function of a fixed sample window — no
17+
// per-utterance mean/std that would shift as audio grows. preemph is causal.
18+
// * The STFT centers with `nFft/2` zero-pad, so mel frame m covers original
19+
// samples [m·hop − nFft/2, m·hop + nFft/2). Frame m is *frozen* — unaffected by
20+
// future audio — once `m·hop + nFft/2 <= bufferLen`. The session only feeds the
21+
// encoder frozen, whole chunks; the trailing partial chunk waits for `finish()`,
22+
// which reproduces the offline right-pad exactly.
23+
// * The cache-aware encoder + greedy RNN-T already carry all their state in
24+
// `NemotronASRStreamEncoderState` / `NemotronASRStreamRNNTState`, so resuming
25+
// across `step` calls reproduces the single-shot walk.
26+
//
27+
// Cost note: v1 recomputes the full mel from the raw buffer each `step` (O(buffer)).
28+
// Mel is ~1% of encode, so this is negligible at utterance scale; an incremental
29+
// mel over a sliding raw window is a future optimization.
30+
31+
/// Per-stream greedy RNN-T state, carried across chunks / `step` calls.
32+
final class NemotronASRStreamRNNTState {
33+
var results: [NemoAlignedToken] = []
34+
var lastToken: Int
35+
var decoderState: NemoLSTMState?
36+
var globalTime = 0 // absolute subsampled-frame index, for token timestamps
37+
38+
init(blankToken: Int) { lastToken = blankToken }
39+
}
40+
41+
extension NemotronASRModel {
42+
/// Greedy RNN-T over one chunk of prompted encoder frames `(1, c, d)`, mutating
43+
/// `state`. Lifted verbatim from the offline streaming loop so the one-shot and
44+
/// session paths share a single decoder (SSOT).
45+
func streamRNNTDecode(
46+
_ prompted: MLXArray,
47+
state: NemotronASRStreamRNNTState,
48+
frameSeconds: Double
49+
) {
50+
let chunkLen = prompted.shape[1]
51+
var time = 0
52+
var newSymbols = 0
53+
while time < chunkLen {
54+
let frame = prompted[0..., time..<(time + 1), 0...]
55+
let currentToken: MLXArray? = state.lastToken == blankTokenID
56+
? nil
57+
: MLXArray(Int32(state.lastToken)).reshaped([1, 1]).asType(.int32)
58+
let decoderOutput = decoder(currentToken, state: state.decoderState)
59+
let pred = decoderOutput.0.asType(frame.dtype)
60+
let proposedState: NemoLSTMState = (
61+
hidden: decoderOutput.1.hidden?.asType(frame.dtype),
62+
cell: decoderOutput.1.cell?.asType(frame.dtype)
63+
)
64+
let jointOutput = joint(frame, pred)
65+
let token = jointOutput.argMax(axis: -1).item(Int.self)
66+
let step = NemoDecodingLogic.rnntStep(
67+
predictedToken: token,
68+
blankToken: blankTokenID,
69+
time: time,
70+
newSymbols: newSymbols,
71+
maxSymbols: maxSymbols
72+
)
73+
if step.emittedToken {
74+
state.lastToken = token
75+
state.decoderState = proposedState
76+
if !NemotronASRTokenizer.isSpecialToken(token, vocabulary: vocabulary) {
77+
state.results.append(
78+
NemoAlignedToken(
79+
id: token,
80+
text: NemotronASRTokenizer.decode(tokens: [token], vocabulary: vocabulary),
81+
start: Double(state.globalTime + time) * frameSeconds,
82+
duration: frameSeconds
83+
)
84+
)
85+
}
86+
}
87+
time = step.nextTime
88+
newSymbols = step.nextNewSymbols
89+
}
90+
state.globalTime += chunkLen
91+
}
92+
}
93+
94+
public final class NemotronASRStreamSession {
95+
/// Text + token ids decoded by a single `step` / `finish` call.
96+
public struct Delta {
97+
public let text: String
98+
public let tokenIds: [Int]
99+
}
100+
101+
private let model: NemotronASRModel
102+
private let language: String?
103+
private let chunkFrames: Int?
104+
private let frameSeconds: Double
105+
106+
private var rawBuffer: [Float] = []
107+
private let encState: NemotronASRStreamEncoderState
108+
private let rnntState: NemotronASRStreamRNNTState
109+
private var emittedText = ""
110+
private var done = false
111+
112+
init(model: NemotronASRModel, language: String?, chunkFrames: Int?) {
113+
self.model = model
114+
self.language = language
115+
self.chunkFrames = chunkFrames
116+
self.encState = NemotronASRStreamEncoderState(layers: model.encoder.layers.count)
117+
self.rnntState = NemotronASRStreamRNNTState(blankToken: model.blankTokenID)
118+
self.frameSeconds = Double(model.encoderConfig.subsamplingFactor * model.preprocessConfig.hopLength)
119+
/ Double(model.preprocessConfig.sampleRate)
120+
let norm = model.preprocessConfig.normalize.lowercased()
121+
precondition(
122+
norm == "na" || norm == "none",
123+
"NemotronASRStreamSession requires NA mel normalization (got \(model.preprocessConfig.normalize)); "
124+
+ "per-utterance normalization is not frozen incrementally."
125+
)
126+
}
127+
128+
/// Full transcript decoded so far.
129+
public var text: String { emittedText }
130+
/// Token ids decoded so far.
131+
public var tokens: [Int] { rnntState.results.map { $0.id } }
132+
/// Whether `finish()` has been called.
133+
public var isFinished: Bool { done }
134+
135+
/// Ingest a chunk of 16 kHz mono samples; returns the text decoded by this call.
136+
@discardableResult
137+
public func step(_ samples: [Float]) -> Delta {
138+
rawBuffer.append(contentsOf: samples)
139+
return advance(final: false)
140+
}
141+
142+
@discardableResult
143+
public func step(_ samples: MLXArray) -> Delta {
144+
let mono = samples.ndim > 1 ? samples.mean(axis: -1) : samples
145+
return step(mono.asType(.float32).asArray(Float.self))
146+
}
147+
148+
/// Flush the trailing partial chunk so the final transcript equals
149+
/// `generateStream(wholeAudio)`. Call once after the last `step`.
150+
@discardableResult
151+
public func finish() -> Delta {
152+
advance(final: true)
153+
}
154+
155+
private func advance(final: Bool) -> Delta {
156+
guard !done else { return Delta(text: "", tokenIds: []) }
157+
guard !rawBuffer.isEmpty else {
158+
if final { done = true }
159+
return Delta(text: "", tokenIds: [])
160+
}
161+
162+
let audio = MLXArray(rawBuffer)
163+
let mel = NemotronASRAudio.logMelSpectrogram(audio, config: model.preprocessConfig) // (1, T, F)
164+
let totalMel = mel.shape[1]
165+
let limit = final ? totalMel : frozenMelFrames(totalMel: totalMel)
166+
167+
let firstNew = rnntState.results.count
168+
model.streamEncodeChunks(
169+
mel,
170+
language: language,
171+
limit: limit,
172+
chunkFrames: chunkFrames,
173+
flushTail: final,
174+
state: encState
175+
) { prompted in
176+
model.streamRNNTDecode(prompted, state: rnntState, frameSeconds: frameSeconds)
177+
}
178+
179+
// Bound the lazy graph across steps: materialize the caches the next step
180+
// resumes from (the in-loop `.item()` already forced this chunk's encoder).
181+
var live: [MLXArray] = []
182+
if let mc = encState.melCache { live.append(mc) }
183+
for c in encState.attnCache where c != nil { live.append(c!) }
184+
for c in encState.convCache where c != nil { live.append(c!) }
185+
if !live.isEmpty { MLX.asyncEval(live) }
186+
187+
let fullText = NemoAlignment.sentencesToResult(
188+
NemoAlignment.tokensToSentences(rnntState.results)
189+
).text
190+
let deltaText = fullText.hasPrefix(emittedText)
191+
? String(fullText.dropFirst(emittedText.count))
192+
: fullText
193+
emittedText = fullText
194+
let deltaIds = rnntState.results[firstNew...].map { $0.id }
195+
196+
if final { done = true }
197+
Memory.clearCache()
198+
return Delta(text: deltaText, tokenIds: Array(deltaIds))
199+
}
200+
201+
/// Number of mel frames whose STFT window is fully covered by real audio, hence
202+
/// bit-identical to the final offline mel regardless of future samples. The STFT
203+
/// centers with `nFft/2` zero-pad, so frame m is frozen iff m·hop + nFft/2 <=
204+
/// bufferLen. Conservative by construction: an under-count only delays a chunk
205+
/// by one `step` (latency), never corrupts output.
206+
private func frozenMelFrames(totalMel: Int) -> Int {
207+
let hop = model.preprocessConfig.hopLength
208+
let half = model.preprocessConfig.nFft / 2
209+
guard rawBuffer.count >= half else { return 0 }
210+
let largestFrozen = (rawBuffer.count - half) / hop
211+
return min(totalMel, largestFrozen + 1)
212+
}
213+
}
214+
215+
public extension NemotronASRModel {
216+
/// Create an online streaming session. Feed audio with `step(_:)`, then `finish()`.
217+
///
218+
/// - Parameters:
219+
/// - language: language prompt (e.g. "ru"); `nil` uses the model default.
220+
/// - chunkMs: chunk size in ms; maps to right-context per the latency ladder
221+
/// (80→[56,0], 160→[56,1], 320→[56,3], 560→[56,6], 1120→[56,13]). `nil` uses
222+
/// the model's native chunk (`default_att_context_size`). Smaller chunks cut
223+
/// latency at a modest WER cost — see the model card's chunk-size table.
224+
func makeStreamSession(
225+
language: String? = nil,
226+
chunkMs: Int? = nil
227+
) -> NemotronASRStreamSession {
228+
let chunkFrames: Int?
229+
if let chunkMs {
230+
let msPerSubframe = encoderConfig.subsamplingFactor * preprocessConfig.hopLength * 1000
231+
/ preprocessConfig.sampleRate // 8*160*1000/16000 = 80 ms
232+
chunkFrames = max(1, Int((Double(chunkMs) / Double(msPerSubframe)).rounded()))
233+
} else {
234+
chunkFrames = nil
235+
}
236+
return NemotronASRStreamSession(model: self, language: language, chunkFrames: chunkFrames)
237+
}
238+
239+
/// Transcribe a whole audio buffer through the online streaming session, feeding
240+
/// fixed `chunkMs`-sized chunks as a live caller would — instead of the whole-buffer
241+
/// `generateStream`. `onDelta` receives each newly decoded fragment as it is produced
242+
/// (use it to render live output); the returned `STTOutput` is the full transcript.
243+
func transcribeStreaming(
244+
audio: MLXArray,
245+
generationParameters: STTGenerateParameters = STTGenerateParameters(),
246+
chunkMs: Int = 480,
247+
onDelta: ((String) -> Void)? = nil
248+
) -> STTOutput {
249+
let mono = audio.ndim > 1 ? audio.mean(axis: -1) : audio
250+
let samples = mono.asType(.float32).asArray(Float.self)
251+
let chunk = max(1, 16000 * chunkMs / 1000)
252+
253+
let session = makeStreamSession(language: generationParameters.language)
254+
let start = CFAbsoluteTimeGetCurrent()
255+
256+
func emit(_ delta: NemotronASRStreamSession.Delta) {
257+
guard !delta.text.isEmpty else { return }
258+
onDelta?(delta.text)
259+
}
260+
var idx = 0
261+
while idx < samples.count {
262+
let end = min(idx + chunk, samples.count)
263+
emit(session.step(Array(samples[idx..<end])))
264+
idx = end
265+
}
266+
emit(session.finish())
267+
268+
let totalTime = CFAbsoluteTimeGetCurrent() - start
269+
let tokenCount = session.tokens.count
270+
return STTOutput(
271+
text: session.text.trimmingCharacters(in: .whitespacesAndNewlines),
272+
language: generationParameters.language,
273+
generationTokens: tokenCount,
274+
totalTokens: tokenCount,
275+
generationTps: totalTime > 0 ? Double(tokenCount) / totalTime : 0,
276+
totalTime: totalTime
277+
)
278+
}
279+
}

0 commit comments

Comments
 (0)