Commit d1b1478

Decouple from tokenizer and downloader packages (#118)

* Decouple from tokenizer and downloader packages

MLX Swift LM currently has two fundamental problems:

- Model loading is tightly coupled to the Hugging Face Hub. A Hub client is required even when loading models from a local directory.
- Model loading performance with Swift Transformers lags far behind the Python equivalent, typically taking several seconds in Swift versus a few hundred milliseconds in Python.

This PR implements the following solutions:

- Swift Transformers is replaced with Swift Tokenizers, a streamlined and optimized fork that focuses purely on tokenizer functionality, with no Hugging Face dependency and no extraneous Core ML code. This unlocks a 10x to 15x speedup in model loading times.
- The `Downloader` protocol abstracts away the model hosting provider, making it easy to use other providers such as ModelScope or to define custom providers, such as downloading from storage buckets.
- Swift Hugging Face, a dedicated client for the Hub, is used in an optional module. No Hugging Face Hub code is bundled for users who don't need it.

API changes:

- The `hub` parameter (previously `HubApi`) has been replaced with `from` (any `Downloader`, or a `URL` for a local directory). Functions that previously defaulted to `defaultHubApi` no longer have a default: callers must either pass a `Downloader` explicitly or use the convenience methods in `MLXLMHuggingFace` / `MLXEmbeddersHuggingFace`, which default to `HubClient.default`. For most users who were using the default Hub client, adding `import MLXLMHuggingFace` or `import MLXEmbeddersHuggingFace` and using the convenience overloads is sufficient. Users who were passing a custom `HubApi` instance should create a `HubClient` instead and pass it as the `from` parameter. `HubClient` conforms to `Downloader` via `MLXLMHuggingFace`.
- `tokenizerId` and `overrideTokenizer` have been replaced by `tokenizerSource: TokenizerSource?`, which supports `.id(String)` for remote sources and `.directory(URL)` for local paths.
- `preparePrompt` has been removed. It shouldn't be used anyway, since support for chat templates is available.
- `modelDirectory(hub:)` has been removed. For local directories, pass the `URL` directly to the loading functions. For remote models, the `Downloader` protocol handles resolution.
- `loadTokenizer(configuration:hub:)` has been removed. Tokenizer loading now uses `AutoTokenizer.from(directory:)` from Swift Tokenizers directly.
- `replacementTokenizers` (the `TokenizerReplacementRegistry`) has been removed. Use `AutoTokenizer.register(_:for:)` from Swift Tokenizers instead.
- The `defaultHubApi` global has been removed. Hugging Face Hub access is now provided by `HubClient.default` from the `HuggingFace` module.

Renamed or replaced internals:

- `downloadModel(hub:configuration:progressHandler:)` → `Downloader.download(id:revision:matching:useLatest:progressHandler:)`
- `loadTokenizerConfig(configuration:hub:)` → `AutoTokenizer.from(directory:)`
- `ModelFactory._load(hub:configuration:progressHandler:)` → `_load(configuration: ResolvedModelConfiguration)`
- `ModelFactory._loadContainer`: removed (the base `loadContainer` now builds the container from `_load`)

---------

Co-authored-by: David Koski <dkoski@apple.com>
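The `Downloader` protocol is described above only in prose. As a rough illustration, here is a minimal, synchronous sketch of how a custom provider could slot in; the real protocol is async and also takes file-matching patterns, a `useLatest` flag, and a progress handler, and `LocalCacheDownloader` is a hypothetical name used only for this example.

```swift
import Foundation

// Simplified, synchronous sketch of the Downloader abstraction.
// A conforming type resolves a model id at a revision to a local directory.
protocol Downloader {
    func download(id: String, revision: String) throws -> URL
}

// Hypothetical provider that serves models from a pre-populated local cache,
// standing in for HubClient, a ModelScope client, or a storage-bucket provider.
struct LocalCacheDownloader: Downloader {
    let cacheRoot: URL

    func download(id: String, revision: String) throws -> URL {
        // Layout assumed for this sketch: <cacheRoot>/<id>/<revision>
        URL(fileURLWithPath: cacheRoot.path + "/" + id + "/" + revision)
    }
}

let downloader = LocalCacheDownloader(cacheRoot: URL(fileURLWithPath: "/tmp/models"))
let dir = try downloader.download(id: "BAAI/bge-small-en-v1.5", revision: "main")
print(dir.path)  // prints "/tmp/models/BAAI/bge-small-en-v1.5/main"
```

Because loading functions accept `any Downloader`, a provider like this can be passed anywhere a `HubClient` would be.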
1 parent 25b00d4 commit d1b1478

File tree

79 files changed: +2927 −1687 lines


Libraries/BenchmarkHelpers/BenchmarkHelpers.swift

Lines changed: 464 additions & 0 deletions
Large diffs are not rendered by default.

Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift

Lines changed: 614 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 10 additions & 0 deletions
```diff
@@ -0,0 +1,10 @@
+# Integration Test Helpers
+
+`IntegrationTestHelpers` and `BenchmarkHelpers` provide shared test logic for verifying end-to-end model loading, inference, tokenizer performance, and download performance. They are designed to be used by integration packages that supply their own `Downloader` and `TokenizerLoader` implementations.
+
+## Integration packages
+
+- [Swift Tokenizers MLX](https://github.com/DePasqualeOrg/swift-tokenizers-mlx): Uses [Swift Tokenizers](https://github.com/DePasqualeOrg/swift-tokenizers) and [Swift HF API](https://github.com/DePasqualeOrg/swift-hf-api)
+- [Swift Transformers MLX](https://github.com/DePasqualeOrg/swift-transformers-mlx): Uses [Swift Transformers](https://github.com/huggingface/swift-transformers) and [Swift Hugging Face](https://github.com/huggingface/swift-huggingface)
+
+Integration tests and benchmarks are run from those packages.
```

Libraries/MLXEmbedders/EmbeddingModel.swift

Lines changed: 8 additions & 11 deletions
```diff
@@ -1,10 +1,9 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
-@preconcurrency import Hub
 import MLX
+import MLXLMCommon
 import MLXNN
-import Tokenizers
 
 /// Container for models that guarantees single threaded access.
 ///
@@ -44,23 +43,21 @@ public actor ModelContainer {
         self.pooler = pooler
     }
 
-    /// build the model and tokenizer without passing non-sendable data over isolation barriers
+    /// Build the model and tokenizer without passing non-sendable data over isolation barriers
    public init(
-        hub: HubApi,
         modelDirectory: URL,
-        configuration: ModelConfiguration
+        tokenizerDirectory: URL,
+        configuration: ModelConfiguration,
+        tokenizerLoader: any TokenizerLoader
     ) async throws {
-        // Load tokenizer config and model in parallel using async let.
-        async let tokenizerConfigTask = loadTokenizerConfig(
-            configuration: configuration, hub: hub)
+        // Load tokenizer and model in parallel
+        async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory)
 
         self.model = try loadSynchronous(
             modelDirectory: modelDirectory, modelName: configuration.name)
         self.pooler = loadPooling(modelDirectory: modelDirectory, model: model)
 
-        let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
-        self.tokenizer = try PreTrainedTokenizer(
-            tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
+        self.tokenizer = try await tokenizerTask
     }
 
     /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
```
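The init above starts the tokenizer load with `async let`, does the synchronous model load on the current task, and only then awaits the tokenizer. A self-contained sketch of that concurrency pattern, with stand-in functions replacing the real model and tokenizer loaders:

```swift
import Foundation

// Stand-in for the async tokenizer load (pretend I/O delay).
func loadTokenizer() async -> String {
    try? await Task.sleep(nanoseconds: 10_000_000)
    return "tokenizer"
}

// Stand-in for the synchronous model load.
func loadModelSynchronous() -> String { "model" }

// Start the tokenizer load, overlap it with the model load, then await.
func loadBoth() async -> (model: String, tokenizer: String) {
    async let tokenizer = loadTokenizer()  // runs concurrently
    let model = loadModelSynchronous()     // overlaps with the tokenizer load
    return (model, await tokenizer)
}
```

The design point is that the slow tokenizer I/O and the synchronous weight loading proceed at the same time, so total latency is roughly the maximum of the two rather than their sum.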

Libraries/MLXEmbedders/Load.swift

Lines changed: 97 additions & 47 deletions
```diff
@@ -1,11 +1,9 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
-@preconcurrency import Hub
 import MLX
 import MLXLMCommon
 import MLXNN
-import Tokenizers
 
 /// Errors encountered during the model loading and initialization process.
 ///
@@ -26,9 +24,6 @@ public enum EmbedderError: LocalizedError {
     /// The configuration file exists but contains invalid JSON or missing required fields.
     case configurationDecodingError(String, String, DecodingError)
 
-    /// Thrown when the tokenizer configuration is missing from the model bundle or Hub.
-    case missingTokenizerConfig
-
     /// A human-readable description of the error.
     public var errorDescription: String? {
         switch self {
@@ -39,8 +34,6 @@ public enum EmbedderError: LocalizedError {
         case .configurationDecodingError(let file, let modelName, let decodingError):
             let errorDetail = extractDecodingErrorDetail(decodingError)
             return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
-        case .missingTokenizerConfig:
-            return "Missing tokenizer configuration"
         }
     }
 
@@ -70,43 +63,48 @@ public enum EmbedderError: LocalizedError {
     }
 }
 
-/// Prepares the local model directory by downloading files from the Hub or resolving a local path.
-///
-/// If the `ModelConfiguration` identifies a remote repo, this function downloads weights
-/// (`.safetensors`) and config files. It includes a fallback mechanism: if the user is
-/// offline or unauthorized, it attempts to resolve the files from the local cache.
+/// Resolve model and tokenizer directories from a ``ModelConfiguration``
+/// using a ``Downloader``.
 ///
 /// - Parameters:
-///   - hub: The `HubApi` instance for managing downloads.
+///   - downloader: The downloader to use for fetching remote resources.
 ///   - configuration: The configuration identifying the model.
+///   - useLatest: When true, always checks the provider for updates.
 ///   - progressHandler: A closure to monitor download progress.
-/// - Returns: A `URL` pointing to the directory containing model files.
-func prepareModelDirectory(
-    hub: HubApi,
+/// - Returns: A tuple of (modelDirectory, tokenizerDirectory).
+func resolveDirectories(
+    from downloader: any Downloader,
     configuration: ModelConfiguration,
+    useLatest: Bool = false,
     progressHandler: @Sendable @escaping (Progress) -> Void
-) async throws -> URL {
-    do {
-        switch configuration.id {
-        case .id(let id):
-            let repo = Hub.Repo(id: id)
-            let modelFiles = ["*.safetensors", "config.json", "*/config.json"]
-            return try await hub.snapshot(
-                from: repo, matching: modelFiles, progressHandler: progressHandler)
-
-        case .directory(let directory):
-            return directory
-        }
-    } catch Hub.HubClientError.authorizationRequired {
-        return configuration.modelDirectory(hub: hub)
-    } catch {
-        let nserror = error as NSError
-        if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet {
-            return configuration.modelDirectory(hub: hub)
-        } else {
-            throw error
-        }
+) async throws -> (modelDirectory: URL, tokenizerDirectory: URL) {
+    let modelDirectory: URL
+    switch configuration.id {
+    case .id(let id, let revision):
+        modelDirectory = try await downloader.download(
+            id: id, revision: revision,
+            matching: modelDownloadPatterns,
+            useLatest: useLatest,
+            progressHandler: progressHandler)
+    case .directory(let directory):
+        modelDirectory = directory
+    }
+
+    let tokenizerDirectory: URL
+    switch configuration.tokenizerSource {
+    case .id(let id, let revision):
+        tokenizerDirectory = try await downloader.download(
+            id: id, revision: revision,
+            matching: tokenizerDownloadPatterns,
+            useLatest: useLatest,
+            progressHandler: { _ in })
+    case .directory(let directory):
+        tokenizerDirectory = directory
+    case nil:
+        tokenizerDirectory = modelDirectory
     }
+
+    return (modelDirectory, tokenizerDirectory)
 }
 
 /// Asynchronously loads the `EmbeddingModel` and its associated `Tokenizer`.
@@ -116,19 +114,23 @@ func prepareModelDirectory(
 /// structure is being built synchronously.
 ///
 /// - Parameters:
-///   - hub: The `HubApi` instance (defaults to a new instance).
+///   - downloader: The ``Downloader`` to use for fetching remote resources.
 ///   - configuration: The model configuration.
+///   - useLatest: When true, always checks the provider for updates.
 ///   - progressHandler: A closure for tracking download progress.
 /// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`.
 public func load(
-    hub: HubApi = defaultHubApi,
+    from downloader: any Downloader,
+    using tokenizerLoader: any TokenizerLoader,
     configuration: ModelConfiguration,
+    useLatest: Bool = false,
     progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
 ) async throws -> (EmbeddingModel, Tokenizer) {
-    let modelDirectory = try await prepareModelDirectory(
-        hub: hub, configuration: configuration, progressHandler: progressHandler)
+    let (modelDirectory, tokenizerDirectory) = try await resolveDirectories(
+        from: downloader, configuration: configuration, useLatest: useLatest,
+        progressHandler: progressHandler)
 
-    async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
+    async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory)
     let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
     let tokenizer = try await tokenizerTask
 
@@ -213,17 +215,65 @@ func loadSynchronous(modelDirectory: URL, modelName: String) throws -> Embedding
 /// or tasks may need to access the embedding model simultaneously.
 ///
 /// - Parameters:
-///   - hub: The `HubApi` instance.
+///   - downloader: The ``Downloader`` to use for fetching remote resources.
 ///   - configuration: The model configuration.
+///   - useLatest: When true, always checks the provider for updates.
 ///   - progressHandler: A closure for tracking download progress.
 /// - Returns: A thread-safe `ModelContainer` instance.
 public func loadModelContainer(
-    hub: HubApi = defaultHubApi,
+    from downloader: any Downloader,
+    using tokenizerLoader: any TokenizerLoader,
     configuration: ModelConfiguration,
+    useLatest: Bool = false,
     progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
 ) async throws -> ModelContainer {
-    let modelDirectory = try await prepareModelDirectory(
-        hub: hub, configuration: configuration, progressHandler: progressHandler)
+    let (modelDirectory, tokenizerDirectory) = try await resolveDirectories(
+        from: downloader, configuration: configuration, useLatest: useLatest,
+        progressHandler: progressHandler)
+
     return try await ModelContainer(
-        hub: hub, modelDirectory: modelDirectory, configuration: configuration)
+        modelDirectory: modelDirectory,
+        tokenizerDirectory: tokenizerDirectory,
+        configuration: configuration,
+        tokenizerLoader: tokenizerLoader)
+}
+
+/// Load an embedding model from a local directory.
+///
+/// No downloader is needed — the model and tokenizer are loaded from
+/// the given directory.
+///
+/// - Parameter directory: The local directory containing model files.
+/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`.
+public func load(
+    from directory: URL,
+    using tokenizerLoader: any TokenizerLoader
+) async throws -> (EmbeddingModel, Tokenizer) {
+    let name =
+        directory.deletingLastPathComponent().lastPathComponent + "/"
+        + directory.lastPathComponent
+    async let tokenizerTask = tokenizerLoader.load(from: directory)
+    let model = try loadSynchronous(modelDirectory: directory, modelName: name)
+    let tokenizer = try await tokenizerTask
+    return (model, tokenizer)
+}
+
+/// Load an embedding model container from a local directory.
+///
+/// No downloader is needed — the model and tokenizer are loaded from
+/// the given directory.
+///
+/// - Parameters:
+///   - directory: The local directory containing model files.
+///   - tokenizerLoader: The ``TokenizerLoader`` to use for loading the tokenizer.
+/// - Returns: A thread-safe `ModelContainer` instance.
+public func loadModelContainer(
+    from directory: URL,
+    using tokenizerLoader: any TokenizerLoader
+) async throws -> ModelContainer {
+    try await ModelContainer(
+        modelDirectory: directory,
+        tokenizerDirectory: directory,
+        configuration: ModelConfiguration(directory: directory),
+        tokenizerLoader: tokenizerLoader)
 }
```
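The tokenizer-directory resolution in `resolveDirectories` above has three branches. A self-contained sketch of just that fallback logic, with the download branch stubbed out (returning `nil` here stands in for a call through the `Downloader`):

```swift
import Foundation

// Mirror of the TokenizerSource cases described in the diff. The real enum
// defaults revision to "main"; this sketch makes it explicit.
enum TokenizerSource {
    case id(String, revision: String)
    case directory(URL)
}

// Resolve the tokenizer directory: .directory is used as-is, nil falls back
// to the model directory, and .id would require a download (stubbed as nil).
func tokenizerDirectory(for source: TokenizerSource?, modelDirectory: URL) -> URL? {
    switch source {
    case .id:
        return nil  // real code calls downloader.download(id:revision:...)
    case .directory(let url):
        return url
    case nil:
        return modelDirectory
    }
}
```

The `nil` fallback is what makes the common case work with zero configuration: most models ship their tokenizer files alongside the weights, so no separate `tokenizerSource` is needed.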

Libraries/MLXEmbedders/Models.swift

Lines changed: 17 additions & 40 deletions
```diff
@@ -1,7 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
-import Hub
+import MLXLMCommon
 
 /// A registry and configuration provider for embedding models.
 ///
@@ -22,7 +22,7 @@ public struct ModelConfiguration: Sendable {
     /// The backing storage for the model's location.
     public enum Identifier: Sendable {
         /// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5").
-        case id(String)
+        case id(String, revision: String = "main")
         /// A file system URL pointing to a local model directory.
         case directory(URL)
     }
@@ -36,67 +36,44 @@ public struct ModelConfiguration: Sendable {
     /// it returns a path-based name (e.g., "ParentDir/ModelDir").
     public var name: String {
         switch id {
-        case .id(let string):
+        case .id(let string, _):
             string
         case .directory(let url):
             url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
         }
     }
 
-    /// An optional alternate Hub ID to use specifically for loading the tokenizer.
+    /// Where to load the tokenizer from when it differs from the model directory.
     ///
-    /// Use this if the model weights and tokenizer configuration are hosted in different repositories.
-    public let tokenizerId: String?
-
-    /// An optional override string for specifying a specific tokenizer implementation.
-    ///
-    /// This is useful for providing compatibility hints to `swift-tokenizers` before
-    /// official support is updated.
-    public let overrideTokenizer: String?
+    /// - `.id`: download from a remote provider (requires a ``Downloader``)
+    /// - `.directory`: load from a local path
+    /// - `nil`: use the same directory as the model
+    public let tokenizerSource: TokenizerSource?
 
     /// Initializes a configuration using a Hub repository ID.
     /// - Parameters:
     ///   - id: The Hugging Face repo ID.
-    ///   - tokenizerId: Optional alternate repo for the tokenizer.
-    ///   - overrideTokenizer: Optional specific tokenizer implementation name.
+    ///   - revision: The Git revision to use (defaults to "main").
+    ///   - tokenizerSource: Optional alternate source for the tokenizer.
     public init(
         id: String,
-        tokenizerId: String? = nil,
-        overrideTokenizer: String? = nil
+        revision: String = "main",
+        tokenizerSource: TokenizerSource? = nil
     ) {
-        self.id = .id(id)
-        self.tokenizerId = tokenizerId
-        self.overrideTokenizer = overrideTokenizer
+        self.id = .id(id, revision: revision)
+        self.tokenizerSource = tokenizerSource
     }
 
     /// Initializes a configuration using a local directory.
     /// - Parameters:
     ///   - directory: The `URL` of the model on disk.
-    ///   - tokenizerId: Optional alternate repo for the tokenizer.
-    ///   - overrideTokenizer: Optional specific tokenizer implementation name.
+    ///   - tokenizerSource: Optional alternate source for the tokenizer.
     public init(
         directory: URL,
-        tokenizerId: String? = nil,
-        overrideTokenizer: String? = nil
+        tokenizerSource: TokenizerSource? = nil
     ) {
         self.id = .directory(directory)
-        self.tokenizerId = tokenizerId
-        self.overrideTokenizer = overrideTokenizer
-    }
-
-    /// Resolves the local file system URL where the model is (or will be) stored.
-    ///
-    /// - Parameter hub: The `HubApi` used to resolve Hub paths.
-    /// - Returns: A `URL` pointing to the local directory.
-    public func modelDirectory(hub: HubApi = HubApi()) -> URL {
-        switch id {
-        case .id(let id):
-            let repo = Hub.Repo(id: id)
-            return hub.localRepoLocation(repo)
-
-        case .directory(let directory):
-            return directory
-        }
+        self.tokenizerSource = tokenizerSource
     }
 
     // MARK: - Registry Management
```
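The `name` property in the diff above derives a display name from either identifier case: repo ids are used verbatim, while local directories become `"ParentDir/ModelDir"`. A self-contained sketch of that derivation (the enum here is a simplified stand-in for `ModelConfiguration.Identifier`, without the default `revision` value):

```swift
import Foundation

// Simplified mirror of ModelConfiguration.Identifier for illustration.
enum Identifier {
    case id(String, revision: String)
    case directory(URL)
}

// Derive the model's display name, matching the logic in the diff:
// repo ids pass through; local paths become "ParentDir/ModelDir".
func name(of identifier: Identifier) -> String {
    switch identifier {
    case .id(let string, _):
        return string
    case .directory(let url):
        return url.deletingLastPathComponent().lastPathComponent
            + "/" + url.lastPathComponent
    }
}
```

Ignoring the revision in the `.id` case means two configurations pointing at different revisions of the same repo share a name, which matches how model names are used for display and error messages.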
