Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Applications/LLMBasic/ChatModel.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright © 2025 Apple Inc.

import MLXLMHuggingFace
import MLXLLM
import MLXLMCommon
import MLXLMTransformers
import SwiftUI

/// Which model to load.
Expand Down Expand Up @@ -40,7 +42,10 @@ private let generateParameters = GenerateParameters(temperature: 0.5)
case .idle:
let task = Task {
// download and report progress
try await loadModelContainer(configuration: modelConfiguration) { value in
try await loadModelContainer(
from: HubClient.default,
configuration: modelConfiguration
) { value in
Task { @MainActor in
self.progress = value.fractionCompleted
}
Expand Down
33 changes: 5 additions & 28 deletions Applications/LLMEval/ViewModels/LLMEvaluator.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright © 2025 Apple Inc.

import Hub
import MLX
import MLXLMHuggingFace
import MLXLLM
import MLXLMCommon
import MLXLMTransformers
import Metal
import SwiftUI

Expand Down Expand Up @@ -101,46 +102,22 @@ class LLMEvaluator {

Memory.cacheLimit = 20 * 1024 * 1024

let hub = HubApi(
downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
)
let hub = HubClient.default

do {
let modelDirectory = try await downloadModel(
hub: hub,
let modelContainer = try await LLMModelFactory.shared.loadContainer(
from: hub,
configuration: modelConfiguration
) { [weak self] progress in
Task { @MainActor in
self?.updateDownloadProgress(progress)
}
}

// Verify the download succeeded by checking for model files
let fileManager = FileManager.default
let directoryExists = fileManager.fileExists(atPath: modelDirectory.path)
let contents = (try? fileManager.contentsOfDirectory(atPath: modelDirectory.path)) ?? []
let hasSafetensors = contents.contains { $0.hasSuffix(".safetensors") }

if !directoryExists || !hasSafetensors {
throw NSError(
domain: "LLMEvaluator",
code: -1,
userInfo: [
NSLocalizedDescriptionKey:
"Model download failed. Please check your network connection and try again."
]
)
}

modelInfo = "Loading \(modelName)..."
downloadProgress = nil
totalSize = nil

let modelContainer = try await LLMModelFactory.shared.loadContainer(
hub: hub,
configuration: modelConfiguration
) { _ in }

let numParams = await modelContainer.perform { $0.model.numParameters() }

self.prompt = PresetPrompts.all[0].prompt
Expand Down
1 change: 0 additions & 1 deletion Applications/LLMEval/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import MLXLLM
import MLXLMCommon
import Metal
import SwiftUI
import Tokenizers

struct ContentView: View {
@Environment(DeviceStat.self) private var deviceStat
Expand Down
6 changes: 4 additions & 2 deletions Applications/LoRATrainingExample/ContentView.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright © 2024 Apple Inc.

import MLX
import MLXLMHuggingFace
import MLXLLM
import MLXLMCommon
import MLXLMTransformers
import MLXNN
import MLXOptimizers
import SwiftUI
import Tokenizers

struct ContentView: View {

Expand Down Expand Up @@ -142,6 +143,7 @@ class LoRAEvaluator {
}

let modelContainer = try await LLMModelFactory.shared.loadContainer(
from: HubClient.default,
configuration: modelConfiguration
) {
progress in
Expand Down Expand Up @@ -269,7 +271,7 @@ class LoRAEvaluator {
input: input, parameters: generateParameters, context: context
) { tokens in
if tokens.count % evaluateShowEvery == 0 {
let fullOutput = context.tokenizer.decode(tokens: tokens)
let fullOutput = context.tokenizer.decode(tokenIds: tokens)
Task { @MainActor in
self.output = fullOutput
}
Expand Down
4 changes: 3 additions & 1 deletion Applications/MLXChatExample/Services/MLXService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import Foundation
import MLX
import MLXLMHuggingFace
import MLXLLM
import MLXLMCommon
import MLXLMTransformers
import MLXVLM

/// A service class that manages machine learning models for text and vision-language tasks.
Expand Down Expand Up @@ -65,7 +67,7 @@ class MLXService {

// Load model and track download progress
let container = try await factory.loadContainer(
hub: .default, configuration: model.configuration
from: HubClient.default, configuration: model.configuration
) { progress in
Task { @MainActor in
self.modelDownloadProgress = progress
Expand Down
24 changes: 0 additions & 24 deletions Applications/MLXChatExample/Support/HubApi+default.swift

This file was deleted.

2 changes: 1 addition & 1 deletion Applications/MLXChatExample/ViewModels/ChatViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ChatViewModel {
case .info(let info):
// Update performance metrics
generateCompletionInfo = info
case .toolCall(let call):
case .toolCall:
break
}
}
Expand Down
76 changes: 53 additions & 23 deletions Libraries/StableDiffusion/Load.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import Hub
import HuggingFace
import MLX
import MLXNN

Expand Down Expand Up @@ -112,24 +112,46 @@ public struct StableDiffusionConfiguration: Sendable {
let files: [FileKey: String]
public let defaultParameters: @Sendable () -> EvaluateParameters
let factory:
@Sendable (HubApi, StableDiffusionConfiguration, LoadConfiguration) throws ->
@Sendable (HubClient, StableDiffusionConfiguration, LoadConfiguration) throws ->
StableDiffusion

/// Failures raised while turning a configuration's `id` into a repository
/// identifier or while locating its downloaded files.
enum Error: LocalizedError {
    /// The given id string could not be converted into a repository identifier.
    case invalidRepositoryID(String)
    /// No downloaded snapshot could be located for the given id string.
    case missingDownloadedSnapshot(String)

    /// Human-readable description of the failure; always includes the offending id.
    var errorDescription: String? {
        let message: String
        switch self {
        case let .invalidRepositoryID(id):
            message = "Invalid Hugging Face repository ID: '\(id)'."
        case let .missingDownloadedSnapshot(id):
            message = "Model files for '\(id)' are not downloaded. Call download() first."
        }
        return message
    }
}

/// The repository identifier built from this configuration's `id`.
///
/// - Throws: `Error.invalidRepositoryID` when `Repo.ID(rawValue:)` returns `nil` for `id`.
fileprivate var repoID: Repo.ID {
    get throws {
        if let parsed = Repo.ID(rawValue: id) {
            return parsed
        }
        throw Error.invalidRepositoryID(id)
    }
}

public func download(
hub: HubApi = HubApi(), progressHandler: @escaping (Progress) -> Void = { _ in }
hub: HubClient = .default, progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws {
let repo = Hub.Repo(id: self.id)
try await hub.snapshot(
from: repo, matching: Array(files.values), progressHandler: progressHandler)
_ = try await hub.downloadSnapshot(
of: repoID, matching: Array(files.values), progressHandler: progressHandler)
}

public func textToImageGenerator(hub: HubApi = HubApi(), configuration: LoadConfiguration)
public func textToImageGenerator(hub: HubClient = .default, configuration: LoadConfiguration)
throws -> TextToImageGenerator?
{
try factory(hub, self, configuration) as? TextToImageGenerator
}

public func imageToImageGenerator(hub: HubApi = HubApi(), configuration: LoadConfiguration)
public func imageToImageGenerator(hub: HubClient = .default, configuration: LoadConfiguration)
throws -> ImageToImageGenerator?
{
try factory(hub, self, configuration) as? ImageToImageGenerator
Expand Down Expand Up @@ -373,64 +395,72 @@ func loadWeights(

// MARK: - Loading

func resolve(hub: HubApi, configuration: StableDiffusionConfiguration, key: FileKey) -> URL {
func resolve(hub: HubClient, configuration: StableDiffusionConfiguration, key: FileKey) throws -> URL {
precondition(
configuration.files[key] != nil, "configuration \(configuration.id) missing key: \(key)")
let repo = Hub.Repo(id: configuration.id)
let directory = hub.localRepoLocation(repo)
let repo = try configuration.repoID
guard let cache = hub.cache,
let revision = cache.resolveRevision(repo: repo, kind: .model, ref: "main")
else {
throw StableDiffusionConfiguration.Error.missingDownloadedSnapshot(configuration.id)
}
let directory = try cache.snapshotPath(repo: repo, kind: .model, commitHash: revision)
guard FileManager.default.fileExists(atPath: directory.path) else {
throw StableDiffusionConfiguration.Error.missingDownloadedSnapshot(configuration.id)
}
return directory.appending(component: configuration.files[key]!)
}

func loadConfiguration<T: Decodable>(
hub: HubApi, configuration: StableDiffusionConfiguration, key: FileKey, type: T.Type
hub: HubClient, configuration: StableDiffusionConfiguration, key: FileKey, type: T.Type
) throws -> T {
let url = resolve(hub: hub, configuration: configuration, key: key)
let url = try resolve(hub: hub, configuration: configuration, key: key)
return try JSONDecoder().decode(T.self, from: Data(contentsOf: url))
}

func loadUnet(hub: HubApi, configuration: StableDiffusionConfiguration, dType: DType) throws
func loadUnet(hub: HubClient, configuration: StableDiffusionConfiguration, dType: DType) throws
-> UNetModel
{
let unetConfiguration = try loadConfiguration(
hub: hub, configuration: configuration, key: .unetConfig, type: UNetConfiguration.self)
let model = UNetModel(configuration: unetConfiguration)

let weightsURL = resolve(hub: hub, configuration: configuration, key: .unetWeights)
let weightsURL = try resolve(hub: hub, configuration: configuration, key: .unetWeights)
try loadWeights(url: weightsURL, model: model, mapper: unetRemap, dType: dType)

return model
}

func loadTextEncoder(
hub: HubApi, configuration: StableDiffusionConfiguration,
hub: HubClient, configuration: StableDiffusionConfiguration,
configKey: FileKey = .textEncoderConfig, weightsKey: FileKey = .textEncoderWeights, dType: DType
) throws -> CLIPTextModel {
let clipConfiguration = try loadConfiguration(
hub: hub, configuration: configuration, key: configKey,
type: CLIPTextModelConfiguration.self)
let model = CLIPTextModel(configuration: clipConfiguration)

let weightsURL = resolve(hub: hub, configuration: configuration, key: weightsKey)
let weightsURL = try resolve(hub: hub, configuration: configuration, key: weightsKey)
try loadWeights(url: weightsURL, model: model, mapper: clipRemap, dType: dType)

return model
}

func loadAutoEncoder(hub: HubApi, configuration: StableDiffusionConfiguration, dType: DType) throws
func loadAutoEncoder(hub: HubClient, configuration: StableDiffusionConfiguration, dType: DType) throws
-> Autoencoder
{
let autoEncoderConfiguration = try loadConfiguration(
hub: hub, configuration: configuration, key: .vaeConfig, type: AutoencoderConfiguration.self
)
let model = Autoencoder(configuration: autoEncoderConfiguration)

let weightsURL = resolve(hub: hub, configuration: configuration, key: .vaeWeights)
let weightsURL = try resolve(hub: hub, configuration: configuration, key: .vaeWeights)
try loadWeights(url: weightsURL, model: model, mapper: vaeRemap, dType: dType)

return model
}

func loadDiffusionConfiguration(hub: HubApi, configuration: StableDiffusionConfiguration) throws
func loadDiffusionConfiguration(hub: HubClient, configuration: StableDiffusionConfiguration) throws
-> DiffusionConfiguration
{
try loadConfiguration(
Expand All @@ -441,11 +471,11 @@ func loadDiffusionConfiguration(hub: HubApi, configuration: StableDiffusionConfi
// MARK: - Tokenizer

func loadTokenizer(
hub: HubApi, configuration: StableDiffusionConfiguration,
hub: HubClient, configuration: StableDiffusionConfiguration,
vocabulary: FileKey = .tokenizerVocabulary, merges: FileKey = .tokenizerMerges
) throws -> CLIPTokenizer {
let vocabularyURL = resolve(hub: hub, configuration: configuration, key: vocabulary)
let mergesURL = resolve(hub: hub, configuration: configuration, key: merges)
let vocabularyURL = try resolve(hub: hub, configuration: configuration, key: vocabulary)
let mergesURL = try resolve(hub: hub, configuration: configuration, key: merges)

let vocabulary = try JSONDecoder().decode(
[String: Int].self, from: Data(contentsOf: vocabularyURL))
Expand Down
8 changes: 4 additions & 4 deletions Libraries/StableDiffusion/StableDiffusion.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import Hub
import HuggingFace
import MLX
import MLXNN

Expand Down Expand Up @@ -214,7 +214,7 @@ open class StableDiffusion {
let tokenizer: CLIPTokenizer

internal init(
hub: HubApi, configuration: StableDiffusionConfiguration, dType: DType,
hub: HubClient, configuration: StableDiffusionConfiguration, dType: DType,
diffusionConfiguration: DiffusionConfiguration? = nil, unet: UNetModel? = nil,
textEncoder: CLIPTextModel? = nil, autoencoder: Autoencoder? = nil,
sampler: SimpleEulerSampler? = nil, tokenizer: CLIPTokenizer? = nil
Expand Down Expand Up @@ -299,7 +299,7 @@ open class StableDiffusion {
/// Implementation of ``StableDiffusion`` for the `stabilityai/stable-diffusion-2-1-base` model.
open class StableDiffusionBase: StableDiffusion, TextToImageGenerator {

public init(hub: HubApi, configuration: StableDiffusionConfiguration, dType: DType) throws {
public init(hub: HubClient, configuration: StableDiffusionConfiguration, dType: DType) throws {
try super.init(hub: hub, configuration: configuration, dType: dType)
}

Expand Down Expand Up @@ -345,7 +345,7 @@ open class StableDiffusionXL: StableDiffusion, TextToImageGenerator, ImageToImag
let textEncoder2: CLIPTextModel
let tokenizer2: CLIPTokenizer

public init(hub: HubApi, configuration: StableDiffusionConfiguration, dType: DType) throws {
public init(hub: HubClient, configuration: StableDiffusionConfiguration, dType: DType) throws {
let diffusionConfiguration = try loadConfiguration(
hub: hub, configuration: configuration, key: .diffusionConfig,
type: DiffusionConfiguration.self)
Expand Down
4 changes: 2 additions & 2 deletions Libraries/StableDiffusion/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ struct Bigram: Hashable {
/// - https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
/// - https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
///
/// Ideally this would be a tokenizer from `swift-transformers` but this is too special purpose to be representable in
/// what exists there (at time of writing).
/// Ideally this would use a shared tokenizer package, but this is too special purpose to be
/// representable in current public APIs.
class CLIPTokenizer {

let pattern =
Expand Down
Loading