|
| 1 | +import Foundation |
| 2 | +import MLX |
| 3 | +import MLXAudioCore |
| 4 | +import MLXNN |
| 5 | + |
| 6 | +// True incremental (online) streaming for Nemotron 3.5 ASR. |
| 7 | +// |
| 8 | +// The offline `generateStream(...)` computes the mel of the whole utterance up |
| 9 | +// front and only then walks the cache-aware encoder. A live caller (e.g. a mic |
| 10 | +// feeding 80 ms chunks) instead wants to push audio as it arrives and read text |
| 11 | +// back with the model's native chunk delay. This session does exactly that, while |
| 12 | +// staying bit-identical to the offline encode. |
| 13 | +// |
| 14 | +// Why it is bit-identical (transcript == generateStream(wholeAudio)): |
| 15 | +// * The preprocessor uses `normalize: "NA"` (verified on the shipped checkpoint), |
| 16 | +// so each mel frame is an independent function of a fixed sample window — no |
| 17 | +// per-utterance mean/std that would shift as audio grows. preemph is causal. |
| 18 | +// * The STFT centers with `nFft/2` zero-pad, so mel frame m covers original |
| 19 | +// samples [m·hop − nFft/2, m·hop + nFft/2). Frame m is *frozen* — unaffected by |
| 20 | +// future audio — once `m·hop + nFft/2 <= bufferLen`. The session only feeds the |
| 21 | +// encoder frozen, whole chunks; the trailing partial chunk waits for `finish()`, |
| 22 | +// which reproduces the offline right-pad exactly. |
| 23 | +// * The cache-aware encoder + greedy RNN-T already carry all their state in |
| 24 | +// `NemotronASRStreamEncoderState` / `NemotronASRStreamRNNTState`, so resuming |
| 25 | +// across `step` calls reproduces the single-shot walk. |
| 26 | +// |
| 27 | +// Cost note: v1 recomputes the full mel from the raw buffer each `step` (O(buffer)). |
| 28 | +// Mel is ~1% of encode, so this is negligible at utterance scale; an incremental |
| 29 | +// mel over a sliding raw window is a future optimization. |
| 30 | + |
/// Mutable greedy RNN-T decoding state for a single stream.
///
/// One instance is shared by every chunk / `step` call of a stream so decoding
/// resumes where the previous chunk stopped.
final class NemotronASRStreamRNNTState {
    /// Aligned tokens accumulated by the decode loop, in emission order.
    var results = [NemoAlignedToken]()
    /// Id of the most recently emitted token; starts as the blank token given to `init`.
    var lastToken: Int
    /// Prediction-network LSTM state; `nil` until the decode loop stores one.
    var decoderState: NemoLSTMState? = nil
    /// Absolute subsampled-frame index of the next chunk's first frame, used for token timestamps.
    var globalTime: Int = 0

    /// - Parameter blankToken: blank token id used to seed `lastToken`.
    init(blankToken: Int) {
        self.lastToken = blankToken
    }
}
| 40 | + |
extension NemotronASRModel {
    /// Runs greedy RNN-T decoding over one chunk of prompted encoder frames and
    /// folds the outcome into `state`.
    ///
    /// The offline streaming loop and the online session both decode through this
    /// single routine, so the two paths cannot drift apart.
    ///
    /// - Parameters:
    ///   - prompted: encoder output for the chunk, indexed as `(1, c, d)`; frames are
    ///     walked along axis 1.
    ///   - state: per-stream decoder state, updated in place (emitted tokens, last
    ///     token, LSTM state, absolute frame counter).
    ///   - frameSeconds: duration of one encoder frame in seconds; used for the
    ///     `start` / `duration` of each emitted token.
    func streamRNNTDecode(
        _ prompted: MLXArray,
        state: NemotronASRStreamRNNTState,
        frameSeconds: Double
    ) {
        let frameCount = prompted.shape[1]
        var frameIndex = 0
        var symbolsOnFrame = 0

        while frameIndex < frameCount {
            let encoderFrame = prompted[0..., frameIndex..<(frameIndex + 1), 0...]

            // When the last token is blank the prediction network is called without an
            // input token; otherwise it is fed the previous token as a (1, 1) int32 array.
            let previousToken: MLXArray?
            if state.lastToken == blankTokenID {
                previousToken = nil
            } else {
                previousToken = MLXArray(Int32(state.lastToken)).reshaped([1, 1]).asType(.int32)
            }

            let decoded = decoder(previousToken, state: state.decoderState)
            let prediction = decoded.0.asType(encoderFrame.dtype)
            // Candidate LSTM state: only committed to `state` if a token is emitted below.
            let candidateState: NemoLSTMState = (
                hidden: decoded.1.hidden?.asType(encoderFrame.dtype),
                cell: decoded.1.cell?.asType(encoderFrame.dtype)
            )

            let logits = joint(encoderFrame, prediction)
            let bestToken = logits.argMax(axis: -1).item(Int.self)

            let outcome = NemoDecodingLogic.rnntStep(
                predictedToken: bestToken,
                blankToken: blankTokenID,
                time: frameIndex,
                newSymbols: symbolsOnFrame,
                maxSymbols: maxSymbols
            )

            if outcome.emittedToken {
                state.lastToken = bestToken
                state.decoderState = candidateState

                // Special tokens advance the decoder but are kept out of the transcript.
                let isSpecial = NemotronASRTokenizer.isSpecialToken(bestToken, vocabulary: vocabulary)
                if !isSpecial {
                    let aligned = NemoAlignedToken(
                        id: bestToken,
                        text: NemotronASRTokenizer.decode(tokens: [bestToken], vocabulary: vocabulary),
                        start: Double(state.globalTime + frameIndex) * frameSeconds,
                        duration: frameSeconds
                    )
                    state.results.append(aligned)
                }
            }

            frameIndex = outcome.nextTime
            symbolsOnFrame = outcome.nextNewSymbols
        }

        // Keep timestamps absolute across chunks.
        state.globalTime += frameCount
    }
}
| 93 | + |
/// Online (incremental) streaming session for Nemotron ASR.
///
/// Push audio with `step(_:)` as it arrives and call `finish()` once after the
/// last chunk. Each call recomputes the mel spectrogram over all buffered audio,
/// hands the encoder only the mel frames that later audio can no longer change,
/// and returns whatever text that call newly decoded.
public final class NemotronASRStreamSession {
    /// Text + token ids decoded by a single `step` / `finish` call.
    public struct Delta {
        public let text: String
        public let tokenIds: [Int]
    }

    private let model: NemotronASRModel
    private let language: String?
    private let chunkFrames: Int?
    /// Duration of one subsampled encoder frame, in seconds.
    private let frameSeconds: Double

    /// Every sample received so far; the mel is recomputed from this on each call.
    private var rawBuffer: [Float] = []
    private let encState: NemotronASRStreamEncoderState
    private let rnntState: NemotronASRStreamRNNTState
    private var emittedText = ""
    private var done = false

    /// - Parameters:
    ///   - model: model whose preprocessor, encoder and decoder drive the session.
    ///   - language: language prompt forwarded to the encoder; `nil` is passed through as-is.
    ///   - chunkFrames: chunk size forwarded to the encoder; `nil` is passed through as-is.
    ///
    /// Traps if the model's mel normalization is anything other than `NA` / `none`,
    /// because per-utterance normalization would change already-consumed frames as
    /// more audio arrives.
    init(model: NemotronASRModel, language: String?, chunkFrames: Int?) {
        self.model = model
        self.language = language
        self.chunkFrames = chunkFrames
        self.encState = NemotronASRStreamEncoderState(layers: model.encoder.layers.count)
        self.rnntState = NemotronASRStreamRNNTState(blankToken: model.blankTokenID)
        self.frameSeconds = Double(model.encoderConfig.subsamplingFactor * model.preprocessConfig.hopLength)
            / Double(model.preprocessConfig.sampleRate)
        let norm = model.preprocessConfig.normalize.lowercased()
        precondition(
            norm == "na" || norm == "none",
            "NemotronASRStreamSession requires NA mel normalization (got \(model.preprocessConfig.normalize)); "
                + "per-utterance normalization is not frozen incrementally."
        )
    }

    /// Full transcript decoded so far.
    public var text: String { emittedText }
    /// Token ids decoded so far.
    public var tokens: [Int] { rnntState.results.map(\.id) }
    /// Whether `finish()` has been called.
    public var isFinished: Bool { done }

    /// Ingest a chunk of 16 kHz mono samples; returns the text decoded by this call.
    ///
    /// After `finish()` the session is closed: samples are ignored (not buffered)
    /// and an empty `Delta` is returned.
    @discardableResult
    public func step(_ samples: [Float]) -> Delta {
        // Audio pushed after `finish()` is never decoded, so do not let it grow the buffer.
        guard !done else { return Delta(text: "", tokenIds: []) }
        rawBuffer.append(contentsOf: samples)
        return advance(final: false)
    }

    /// Ingest samples held in an `MLXArray`. Arrays with more than one dimension
    /// are averaged over their last axis before being buffered.
    @discardableResult
    public func step(_ samples: MLXArray) -> Delta {
        // Skip the tensor-to-array conversion entirely once the session is closed.
        guard !done else { return Delta(text: "", tokenIds: []) }
        let mono = samples.ndim > 1 ? samples.mean(axis: -1) : samples
        return step(mono.asType(.float32).asArray(Float.self))
    }

    /// Flush the trailing partial chunk so the final transcript equals
    /// `generateStream(wholeAudio)`. Call once after the last `step`.
    @discardableResult
    public func finish() -> Delta {
        advance(final: true)
    }

    /// Re-derives the mel from the whole buffer, advances the encoder/decoder as far
    /// as allowed, and returns what this call added to the transcript.
    ///
    /// - Parameter final: when `true`, every mel frame is eligible, the encoder is
    ///   asked to flush its tail, and the session is marked finished.
    private func advance(final: Bool) -> Delta {
        guard !done else { return Delta(text: "", tokenIds: []) }
        guard !rawBuffer.isEmpty else {
            if final { done = true }
            return Delta(text: "", tokenIds: [])
        }

        let audio = MLXArray(rawBuffer)
        let mel = NemotronASRAudio.logMelSpectrogram(audio, config: model.preprocessConfig) // (1, T, F)
        let totalMel = mel.shape[1]
        let limit = final ? totalMel : frozenMelFrames(totalMel: totalMel)

        // Remember where this call's tokens start so the delta ids can be sliced out.
        let firstNew = rnntState.results.count
        model.streamEncodeChunks(
            mel,
            language: language,
            limit: limit,
            chunkFrames: chunkFrames,
            flushTail: final,
            state: encState
        ) { prompted in
            model.streamRNNTDecode(prompted, state: rnntState, frameSeconds: frameSeconds)
        }

        // Bound the lazy graph across steps: schedule evaluation of the caches the next
        // step resumes from (the in-loop `.item()` already forced this chunk's encoder).
        var live: [MLXArray] = []
        if let melCache = encState.melCache { live.append(melCache) }
        for case let cache? in encState.attnCache { live.append(cache) }
        for case let cache? in encState.convCache { live.append(cache) }
        if !live.isEmpty { MLX.asyncEval(live) }

        let fullText = NemoAlignment.sentencesToResult(
            NemoAlignment.tokensToSentences(rnntState.results)
        ).text
        // Normally the new transcript extends the old one; if it does not, hand back
        // the whole transcript rather than a misaligned suffix.
        let deltaText = fullText.hasPrefix(emittedText)
            ? String(fullText.dropFirst(emittedText.count))
            : fullText
        emittedText = fullText
        let deltaIds = rnntState.results[firstNew...].map(\.id)

        if final { done = true }
        Memory.clearCache()
        return Delta(text: deltaText, tokenIds: deltaIds)
    }

    /// Number of mel frames whose STFT window is fully covered by real audio, hence
    /// bit-identical to the final offline mel regardless of future samples. The STFT
    /// centers with `nFft/2` zero-pad, so frame m is frozen iff m·hop + nFft/2 <=
    /// bufferLen. Conservative by construction: an under-count only delays a chunk
    /// by one `step` (latency), never corrupts output.
    private func frozenMelFrames(totalMel: Int) -> Int {
        let hop = model.preprocessConfig.hopLength
        let half = model.preprocessConfig.nFft / 2
        guard rawBuffer.count >= half else { return 0 }
        let largestFrozen = (rawBuffer.count - half) / hop
        return min(totalMel, largestFrozen + 1)
    }
}
| 214 | + |
public extension NemotronASRModel {
    /// Create an online streaming session. Feed audio with `step(_:)`, then `finish()`.
    ///
    /// - Parameters:
    ///   - language: language prompt (e.g. "ru"); `nil` uses the model default.
    ///   - chunkMs: chunk size in ms; maps to right-context per the latency ladder
    ///     (80→[56,0], 160→[56,1], 320→[56,3], 560→[56,6], 1120→[56,13]). `nil` uses
    ///     the model's native chunk (`default_att_context_size`). Smaller chunks cut
    ///     latency at a modest WER cost — see the model card's chunk-size table.
    /// - Returns: a fresh session bound to this model.
    func makeStreamSession(
        language: String? = nil,
        chunkMs: Int? = nil
    ) -> NemotronASRStreamSession {
        let chunkFrames: Int?
        if let chunkMs {
            // Duration of one subsampled encoder frame in ms, e.g. 8*160*1000/16000 = 80.
            // Computed in floating point so configurations where this is not a whole
            // number of milliseconds are not truncated before the rounding below.
            let msPerSubframe = Double(encoderConfig.subsamplingFactor * preprocessConfig.hopLength) * 1000
                / Double(preprocessConfig.sampleRate)
            chunkFrames = max(1, Int((Double(chunkMs) / msPerSubframe).rounded()))
        } else {
            chunkFrames = nil
        }
        return NemotronASRStreamSession(model: self, language: language, chunkFrames: chunkFrames)
    }

    /// Transcribe a whole audio buffer through the online streaming session, feeding
    /// fixed `chunkMs`-sized chunks as a live caller would — instead of the whole-buffer
    /// `generateStream`. `onDelta` receives each newly decoded fragment as it is produced
    /// (use it to render live output); the returned `STTOutput` is the full transcript.
    ///
    /// - Parameters:
    ///   - audio: input samples; arrays with more than one dimension are averaged over
    ///     their last axis.
    ///   - generationParameters: only `language` is read here.
    ///   - chunkMs: how much audio is pushed per `step` call. This is the feed size only;
    ///     the session itself is created with the model's native encoder chunk.
    ///   - onDelta: called with every non-empty text fragment, in order.
    /// - Returns: the trimmed transcript plus token count and timing statistics.
    func transcribeStreaming(
        audio: MLXArray,
        generationParameters: STTGenerateParameters = STTGenerateParameters(),
        chunkMs: Int = 480,
        onDelta: ((String) -> Void)? = nil
    ) -> STTOutput {
        let mono = audio.ndim > 1 ? audio.mean(axis: -1) : audio
        let samples = mono.asType(.float32).asArray(Float.self)
        // Derive the feed size from the model's configured sample rate rather than a
        // literal, consistent with how the session derives its frame timing.
        let samplesPerFeed = max(1, preprocessConfig.sampleRate * chunkMs / 1000)

        let session = makeStreamSession(language: generationParameters.language)
        let startTime = CFAbsoluteTimeGetCurrent()

        func emit(_ delta: NemotronASRStreamSession.Delta) {
            guard !delta.text.isEmpty else { return }
            onDelta?(delta.text)
        }

        for lower in stride(from: 0, to: samples.count, by: samplesPerFeed) {
            let upper = min(lower + samplesPerFeed, samples.count)
            emit(session.step(Array(samples[lower..<upper])))
        }
        emit(session.finish())

        let totalTime = CFAbsoluteTimeGetCurrent() - startTime
        let tokenCount = session.tokens.count
        return STTOutput(
            text: session.text.trimmingCharacters(in: .whitespacesAndNewlines),
            language: generationParameters.language,
            generationTokens: tokenCount,
            totalTokens: tokenCount,
            generationTps: totalTime > 0 ? Double(tokenCount) / totalTime : 0,
            totalTime: totalTime
        )
    }
}
0 commit comments