Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions Sources/MLXAudioSTT/Models/Parakeet/ParakeetCoreMLEncoder.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#if canImport(CoreML)
import CoreML
import Foundation
import MLX

/// Drop-in CoreML/ANE replacement for the MLX Conformer encoder.
///
/// The model is fixed-shape because ANE requires it (RangeDim → 0% residency), so each
/// chunk's mel is padded (or truncated) to `fixedFrames` and the output is cropped back
/// to the true subsampled length.
public final class ParakeetCoreMLEncoder: @unchecked Sendable {
    private let model: MLModel
    private let featIn: Int
    private let fixedFrames: Int
    private let subsamplingFactor: Int
    private let inputName: String
    private let outputName: String

    /// Loads the CoreML encoder at `modelURL`, compiling it first when necessary.
    ///
    /// - Parameters:
    ///   - modelURL: A compiled `.mlmodelc`; any other URL is passed through
    ///     `MLModel.compileModel(at:)` before loading.
    ///   - featIn: Number of mel features per frame fed to the model.
    ///   - fixedFrames: Fixed number of input frames every chunk is padded/truncated to.
    ///   - subsamplingFactor: Time-subsampling factor used to derive the valid output length.
    ///   - computeUnits: CoreML compute units to request.
    ///   - inputName: Name of the model's input feature.
    ///   - outputName: Name of the model's output feature.
    /// - Throws: `STTError.invalidInput` when a dimension is not positive, or any error
    ///   raised while compiling or loading the model.
    public init(
        modelURL: URL,
        featIn: Int,
        fixedFrames: Int,
        subsamplingFactor: Int,
        computeUnits: MLComputeUnits = .all,
        inputName: String = "features",
        outputName: String = "encoded"
    ) throws {
        // Reject sizes that would later trap in range construction or buffer math.
        guard featIn > 0, fixedFrames > 0, subsamplingFactor > 0 else {
            throw STTError.invalidInput(
                "CoreML encoder requires positive featIn/fixedFrames/subsamplingFactor, got "
                    + "\(featIn)/\(fixedFrames)/\(subsamplingFactor)")
        }

        let compiledURL: URL
        if modelURL.pathExtension == "mlmodelc" {
            compiledURL = modelURL
        } else {
            compiledURL = try MLModel.compileModel(at: modelURL)
        }
        let config = MLModelConfiguration()
        config.computeUnits = computeUnits
        self.model = try MLModel(contentsOf: compiledURL, configuration: config)
        self.featIn = featIn
        self.fixedFrames = fixedFrames
        self.subsamplingFactor = subsamplingFactor
        self.inputName = inputName
        self.outputName = outputName
    }

    /// Matches `ParakeetModel.computeEncodedLengths`: `floor((L-1)/2)+1`, log2(factor) times.
    ///
    /// The number of halving steps is `floor(log2(subsamplingFactor))`, computed with integer
    /// bit math so a factor of one or below yields zero steps instead of trapping while
    /// converting `-inf`/`nan` to `Int`.
    ///
    /// - Note: Swift's `/` truncates toward zero, so `frames == 0` maps to 1 here.
    static func subsampledLength(frames: Int, subsamplingFactor: Int) -> Int {
        guard subsamplingFactor > 1 else { return frames }
        let steps = Int.bitWidth - 1 - subsamplingFactor.leadingZeroBitCount
        var length = frames
        for _ in 0..<steps { length = (length - 1) / 2 + 1 }
        return length
    }

    private func encodedLength(for frames: Int) -> Int {
        Self.subsampledLength(frames: frames, subsamplingFactor: subsamplingFactor)
    }

    /// Encode one chunk. `features`: `[1, T, featIn]` (any float dtype).
    /// Returns `(encoded [1, T', dModel], lengths [1])`, dtype = `outputDType`.
    ///
    /// Inputs longer than `fixedFrames` are truncated to `fixedFrames`; shorter inputs are
    /// zero-padded on the time axis.
    ///
    /// - Throws: `STTError.invalidInput` when `features` is not shaped `[1, T, featIn]`, or
    ///   when the model output is missing, is not rank 3, or has an unsupported data type;
    ///   also rethrows CoreML prediction errors.
    public func encode(_ features: MLXArray, outputDType: DType) throws -> (MLXArray, MLXArray) {
        let shape = features.shape
        // The transposition loop below indexes a single batch of `featIn`-wide frames;
        // anything else would read out of bounds or silently drop batches.
        guard shape.count == 3, shape[0] == 1, shape[2] == featIn else {
            throw STTError.invalidInput(
                "CoreML encoder expects features shaped [1, T, \(featIn)], got \(shape)")
        }
        let trueFrames = shape[1]
        let clamped = min(trueFrames, fixedFrames)

        var mel = features.asType(.float32)
        if trueFrames < fixedFrames {
            mel = padded(mel, widths: [.init((0, 0)), .init((0, fixedFrames - trueFrames)), .init((0, 0))])
        } else if trueFrames > fixedFrames {
            mel = mel[0..., 0..<fixedFrames, 0...]
        }
        // mel is [1, fixedFrames, featIn] row-major; CoreML wants [1, featIn, fixedFrames].
        let melFlat = mel.asArray(Float.self)  // index = t * featIn + f
        let input = try MLMultiArray(
            shape: [1, NSNumber(value: featIn), NSNumber(value: fixedFrames)], dataType: .float32)
        input.dataPointer.withMemoryRebound(to: Float.self, capacity: featIn * fixedFrames) { dst in
            for t in 0..<fixedFrames {
                for f in 0..<featIn {
                    dst[f * fixedFrames + t] = melFlat[t * featIn + f]
                }
            }
        }

        let provider = try MLDictionaryFeatureProvider(
            dictionary: [inputName: MLFeatureValue(multiArray: input)])
        let out = try model.prediction(from: provider)
        guard let enc = out.featureValue(for: outputName)?.multiArrayValue else {
            throw STTError.invalidInput("CoreML encoder produced no '\(outputName)' output")
        }
        guard enc.shape.count == 3, enc.strides.count == 3 else {
            throw STTError.invalidInput(
                "CoreML encoder output '\(outputName)' must be rank 3, got shape \(enc.shape)")
        }

        let dModel = enc.shape[1].intValue
        let tFull = enc.shape[2].intValue
        guard dModel > 0, tFull > 0 else {
            throw STTError.invalidInput(
                "CoreML encoder output '\(outputName)' is empty (shape \(enc.shape))")
        }
        let encFloats = try packedFloats(from: enc, rows: dModel, columns: tFull)

        // Never report more frames than the model actually produced.
        let validLen = min(encodedLength(for: clamped), tFull)
        var encoded = MLXArray(encFloats, [1, dModel, tFull]).transposed(0, 2, 1)
        if validLen < tFull {
            encoded = encoded[0..., 0..<validLen, 0...]
        }
        let lengths = MLXArray([Int32(validLen)])
        return (encoded.asType(outputDType), lengths)
    }

    /// Copies a `[1, rows, columns]` multi-array into a packed buffer indexed
    /// `row * columns + column`.
    ///
    /// ANE outputs are often stride-padded, so strides are honored rather than reading the
    /// raw buffer sequentially (which would scramble frames).
    ///
    /// - Throws: `STTError.invalidInput` for data types other than float16/float32/double.
    private func packedFloats(from array: MLMultiArray, rows: Int, columns: Int) throws -> [Float] {
        let rowStride = array.strides[1].intValue
        let columnStride = array.strides[2].intValue
        let capacity = (rows - 1) * rowStride + (columns - 1) * columnStride + 1
        var packed = [Float](repeating: 0, count: rows * columns)

        func gather<Element>(_: Element.Type, _ convert: (Element) -> Float) {
            let source = array.dataPointer.bindMemory(to: Element.self, capacity: capacity)
            for row in 0..<rows {
                for column in 0..<columns {
                    packed[row * columns + column] = convert(source[row * rowStride + column * columnStride])
                }
            }
        }

        switch array.dataType {
        case .float16:
            gather(UInt16.self) { Float(Float16(bitPattern: $0)) }
        case .float32:
            gather(Float.self) { $0 }
        case .double:
            gather(Double.self) { Float($0) }
        default:
            throw STTError.invalidInput(
                "CoreML encoder output '\(outputName)' has unsupported data type "
                    + "(raw value \(array.dataType.rawValue))")
        }
        return packed
    }
}
#endif
95 changes: 93 additions & 2 deletions Sources/MLXAudioSTT/Models/Parakeet/ParakeetModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public final class ParakeetModel: Module, STTGenerationModel {
/// Selects which encoder execution path is used.
enum EncoderExecutionImplementation: Sendable {
    case plain, compiled, coreML
}

struct TDTTraceStep: Sendable, Equatable {
Expand All @@ -52,6 +53,9 @@ public final class ParakeetModel: Module, STTGenerationModel {

/// Optional TDT decoder implementation override.
/// NOTE(review): how a `nil` value is resolved is not visible in this hunk — confirm.
var tdtDecoderImplementation: TDTDecoderImplementation?
/// Optional encoder execution path; `enableCoreMLEncoder(modelURL:fixedFrames:)` sets it
/// to `.coreML`.
var encoderExecutionImplementation: EncoderExecutionImplementation?
#if canImport(CoreML)
/// CoreML encoder wrapper assigned by `enableCoreMLEncoder(modelURL:fixedFrames:)`;
/// `nil` until then.
var coreMLEncoder: ParakeetCoreMLEncoder?
#endif
/// Optional callback receiving `TDTTraceStep` values.
/// NOTE(review): presumably invoked during TDT decoding — confirm at the call site.
var tdtTraceEmitter: (@Sendable (TDTTraceStep) -> Void)?
/// Encoder closures looked up by the `"\(shape)-\(dtype)"` key built in
/// `compiledEncoderFeatures(for:)`.
private var compiledEncoderFeaturesByShape: [String: @Sendable (MLXArray) -> MLXArray] = [:]

Expand Down Expand Up @@ -313,9 +317,93 @@ public final class ParakeetModel: Module, STTGenerationModel {
let encodedFeatures = compiledEncoderFeatures(for: features)(features)
let encodedLengths = computeEncodedLengths(from: resolvedLengths)
return (encodedFeatures, encodedLengths)
case .coreML:
#if canImport(CoreML)
if let coreMLEncoder,
let result = try? coreMLEncoder.encode(features, outputDType: computeDType) {
return result
}
#endif
return encoder(features, lengths: resolvedLengths) // fallback if CoreML unavailable
}
}

/// How to source the optional CoreML/ANE Conformer encoder (default `.off` = pure MLX).
public enum ANEEncoder: Sendable {
/// Do not use a CoreML encoder.
case off
/// Download `defaultANEEncoderRepo` from Hugging Face.
case on
/// Download a specific Hugging Face repo (the associated value is the repo identifier).
case repo(String)
/// Use a local `.mlpackage` / `.mlmodelc` at the given URL.
case package(URL)
}

/// Hugging Face repo identifier used when the ANE encoder is requested with `.on`.
public static let defaultANEEncoderRepo = "beshkenadze/parakeet-tdt-0.6b-v3-coreml-ane"

/// Configure the optional CoreML/ANE encoder according to `option`.
///
/// `.off` leaves the model untouched, and on platforms where CoreML cannot be imported
/// the call does nothing at all.
///
/// - Parameters:
///   - option: Where the CoreML encoder should come from.
///   - cache: Hub cache forwarded to the download path for `.on` and `.repo`.
/// - Throws: Whatever the selected `enableCoreMLEncoder` overload throws.
public func applyANEEncoder(_ option: ANEEncoder, cache: HubCache = .default) async throws {
    #if canImport(CoreML)
    switch option {
    case .off:
        return
    case .package(let packageURL):
        try enableCoreMLEncoder(modelURL: packageURL)
    case .on:
        try await enableCoreMLEncoder(repo: Self.defaultANEEncoderRepo, cache: cache)
    case .repo(let repoName):
        try await enableCoreMLEncoder(repo: repoName, cache: cache)
    }
    #endif
}

#if canImport(CoreML)
/// Switch the Conformer encoder over to CoreML/ANE; decoder and chunking stay in MLX.
///
/// The encoder is built first, so a failure to load the model leaves the current encoder
/// configuration untouched.
///
/// - Parameters:
///   - modelURL: Location of the CoreML encoder model.
///   - fixedFrames: Fixed input frame count handed to the CoreML encoder.
/// - Throws: Any error raised while creating `ParakeetCoreMLEncoder`.
public func enableCoreMLEncoder(modelURL: URL, fixedFrames: Int = 1000) throws {
    let aneEncoder = try ParakeetCoreMLEncoder(
        modelURL: modelURL,
        featIn: encoderConfig.featIn,
        fixedFrames: fixedFrames,
        subsamplingFactor: encoderConfig.subsamplingFactor
    )
    coreMLEncoder = aneEncoder
    encoderExecutionImplementation = .coreML
}

/// Fetch a CoreML encoder package from a Hugging Face repo and route the encoder through it.
///
/// - Parameters:
///   - repo: Hugging Face repo identifier holding the encoder package.
///   - cache: Hub cache used for the download.
/// - Throws: Download errors, or any error from `enableCoreMLEncoder(modelURL:fixedFrames:)`.
public func enableCoreMLEncoder(repo: String, cache: HubCache = .default) async throws {
    let packageURL = try await Self.downloadANEEncoderPackage(repo: repo, cache: cache)
    try enableCoreMLEncoder(modelURL: packageURL)
}

/// Resolves the CoreML encoder package for `repo`, downloading it when it is not cached.
///
/// The package lives under `<cacheDirectory>/mlx-audio/<repo with "/" replaced by "_">`.
/// If that directory already contains a `.mlpackage`/`.mlmodelc`, it is returned without
/// touching the network; otherwise a snapshot of revision `main` is downloaded there.
///
/// An `HF_TOKEN` taken from the process environment (or, failing that, the main bundle's
/// Info dictionary) is used as the bearer token when it is non-empty.
///
/// - Parameters:
///   - repo: Hugging Face repo identifier.
///   - cache: Hub cache handed to the client.
/// - Returns: URL of the encoder package found in the download directory.
/// - Throws: `NSError` (domain `ParakeetModel`) with code 2 for an invalid repo identifier
///   or code 3 when no package is found after downloading; also rethrows directory
///   creation and download errors.
static func downloadANEEncoderPackage(repo: String, cache: HubCache = .default) async throws -> URL {
    guard let repoID = Repo.ID(rawValue: repo) else {
        throw NSError(domain: "ParakeetModel", code: 2,
                      userInfo: [NSLocalizedDescriptionKey: "Invalid ANE encoder repo: \(repo)"])
    }

    let hfToken = ProcessInfo.processInfo.environment["HF_TOKEN"]
        ?? Bundle.main.object(forInfoDictionaryKey: "HF_TOKEN") as? String
    // Bind the token instead of force-unwrapping it; an empty token means anonymous access.
    let client: HubClient
    if let hfToken, !hfToken.isEmpty {
        client = HubClient(host: HubClient.defaultHost, bearerToken: hfToken, cache: cache)
    } else {
        client = HubClient(cache: cache)
    }

    let dir = (client.cache ?? cache).cacheDirectory
        .appendingPathComponent("mlx-audio")
        .appendingPathComponent(repo.replacingOccurrences(of: "/", with: "_"))

    if let cached = findEncoderPackage(in: dir) { return cached }

    try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true)
    _ = try await client.downloadSnapshot(
        of: repoID, kind: .model, to: dir, revision: "main",
        matching: ["*.json", "*.mlmodel", "*.bin", "*.weights", "*.mil", "*.espresso.*"],
        progressHandler: { _ in }
    )
    guard let pkg = findEncoderPackage(in: dir) else {
        throw NSError(domain: "ParakeetModel", code: 3,
                      userInfo: [NSLocalizedDescriptionKey: "No .mlpackage/.mlmodelc found in \(repo)"])
    }
    return pkg
}

/// Returns a `.mlpackage` or `.mlmodelc` entry located directly inside `dir`.
///
/// `contentsOfDirectory` returns entries in an undefined order, so when several candidates
/// exist the one whose file name sorts first is chosen, keeping the result stable across
/// runs and file systems.
///
/// - Parameter dir: Directory to scan (not searched recursively).
/// - Returns: The chosen package URL, or `nil` when the directory cannot be read or holds
///   no matching entry.
static func findEncoderPackage(in dir: URL) -> URL? {
    let packageExtensions: Set<String> = ["mlpackage", "mlmodelc"]
    guard let items = try? FileManager.default.contentsOfDirectory(
        at: dir, includingPropertiesForKeys: nil)
    else { return nil }
    return items
        .filter { packageExtensions.contains($0.pathExtension) }
        .min { $0.lastPathComponent < $1.lastPathComponent }
}
#endif

func compiledEncoderFeatures(for features: MLXArray) -> @Sendable (MLXArray) -> MLXArray {
let key = "\(features.shape)-\(features.dtype)"
if let compiled = compiledEncoderFeaturesByShape[key] {
Expand Down Expand Up @@ -1054,7 +1142,8 @@ public extension ParakeetModel {
static func fromPretrained(
_ modelPath: String,
computeDType: DType = .bfloat16,
cache: HubCache = .default
cache: HubCache = .default,
aneEncoder: ANEEncoder = .off
) async throws -> ParakeetModel {
let hfToken: String? = ProcessInfo.processInfo.environment["HF_TOKEN"]
?? Bundle.main.object(forInfoDictionaryKey: "HF_TOKEN") as? String
Expand All @@ -1073,7 +1162,9 @@ public extension ParakeetModel {
hfToken: hfToken,
cache: cache
)
return try fromDirectory(modelDir, computeDType: computeDType)
let model = try fromDirectory(modelDir, computeDType: computeDType)
try await model.applyANEEncoder(aneEncoder, cache: cache)
return model
}
}

Expand Down
21 changes: 21 additions & 0 deletions Sources/Tools/mlx-audio-swift-stt/App.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ private struct Options {
var prefillStepSize = 2048
var genKwargsRaw: String? = nil
var text = ""
var coremlEncoder: String? = nil
var ane = false

var temperature: Float? = nil
var topP: Float? = nil
Expand Down Expand Up @@ -139,6 +141,11 @@ private struct Options {
case "--text":
guard let v = it.next() else { throw CLIError.missingValue(arg) }
options.text = v
case "--coreml-encoder":
guard let v = it.next() else { throw CLIError.missingValue(arg) }
options.coremlEncoder = v
case "--ane":
options.ane = true
case "--help", "-h":
printUsage()
exit(0)
Expand Down Expand Up @@ -271,6 +278,18 @@ enum App {
let (inputSampleRate, inputAudio) = try loadAudioArray(from: inputURL)
let audio = try prepareAudioForSTT(inputAudio, inputSampleRate: inputSampleRate, targetSampleRate: 16000)

#if canImport(CoreML)
if case .stt(let m) = model, let parakeet = m as? ParakeetModel {
if let coremlPath = options.coremlEncoder {
try parakeet.enableCoreMLEncoder(modelURL: resolveURL(path: coremlPath))
if options.verbose { print("CoreML/ANE encoder enabled: \(coremlPath)") }
} else if options.ane {
try await parakeet.enableCoreMLEncoder(repo: ParakeetModel.defaultANEEncoderRepo)
if options.verbose { print("ANE encoder enabled: \(ParakeetModel.defaultANEEncoderRepo)") }
}
}
#endif

let startTime = CFAbsoluteTimeGetCurrent()

if options.verbose {
Expand Down Expand Up @@ -346,9 +365,11 @@ enum App {

if options.verbose {
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
let audioSeconds = Double(audio.size) / 16000.0
print("\n==========")
print("Saved file to: \(options.outputPath!).\(options.format.rawValue)")
print(String(format: "Processing time: %.2f seconds", elapsed))
print(String(format: "RTF: %.1fx realtime (%.1fs audio / %.2fs)", audioSeconds / elapsed, audioSeconds, elapsed))
print(String(format: "Prompt: %d tokens, %.3f tokens-per-sec", output.promptTokens, output.promptTps))
print(String(format: "Generation: %d tokens, %.3f tokens-per-sec", output.generationTokens, output.generationTps))
print(String(format: "Peak memory: %.2f GB", output.peakMemoryUsage))
Expand Down
41 changes: 41 additions & 0 deletions Tests/ParakeetCoreMLEncoderTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#if canImport(CoreML)
import Foundation
import Testing

@testable import MLXAudioSTT

@Suite("Parakeet CoreML Encoder Tests")
struct ParakeetCoreMLEncoderTests {
    /// The wrapper's output-length math must match `ParakeetModel.computeEncodedLengths`
    /// (NeMo dw-striding: `floor((L-1)/2)+1`, log2(factor) times).
    @Test func subsampledLengthMatchesDwStriding() {
        #expect(ParakeetCoreMLEncoder.subsampledLength(frames: 1000, subsamplingFactor: 8) == 125)
        #expect(ParakeetCoreMLEncoder.subsampledLength(frames: 995, subsamplingFactor: 8) == 125)
        #expect(ParakeetCoreMLEncoder.subsampledLength(frames: 128, subsamplingFactor: 8) == 16)
        #expect(ParakeetCoreMLEncoder.subsampledLength(frames: 1, subsamplingFactor: 8) == 1)
    }

    /// A missing/invalid `.mlpackage` must surface as a thrown error from the initializer,
    /// never a crash.
    @Test func throwsOnMissingModel() {
        let bogus = URL(fileURLWithPath: "/nonexistent/parakeet_enc.mlpackage")
        #expect(throws: (any Error).self) {
            _ = try ParakeetCoreMLEncoder(
                modelURL: bogus, featIn: 128, fixedFrames: 1000, subsamplingFactor: 8)
        }
    }

    /// Resolving the downloaded encoder picks the `.mlpackage` (or `.mlmodelc`) directory.
    @Test func findsEncoderPackageInDirectory() throws {
        let fm = FileManager.default
        // A unique directory per run keeps parallel or repeated runs from sharing state.
        let base = fm.temporaryDirectory
            .appendingPathComponent("parakeet-coreml-findtest-\(UUID().uuidString)")
        // Registered before any file-system work so cleanup also runs if creation throws.
        defer { try? fm.removeItem(at: base) }
        let pkg = base.appendingPathComponent("enc.mlpackage")
        try fm.createDirectory(at: pkg, withIntermediateDirectories: true)

        #expect(ParakeetModel.findEncoderPackage(in: base)?.lastPathComponent == "enc.mlpackage")
        #expect(ParakeetModel.findEncoderPackage(in: pkg) == nil)  // empty dir → nothing
    }
}
#endif
Loading