-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathEncoder.swift
More file actions
121 lines (104 loc) · 4.58 KB
/
Encoder.swift
File metadata and controls
121 lines (104 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
/// An encoder model which produces latent samples from RGB images
@available(iOS 16.2, macOS 13.1, *)
public struct Encoder: ResourceManaging {

    public enum Error: String, Swift.Error {
        case sampleInputShapeNotCorrect
    }

    /// VAE encoder model + post math and adding noise from scheduler
    var model: ManagedMLModel

    /// Create encoder from Core ML model
    ///
    /// - Parameters:
    ///   - url: Location of compiled VAE encoder Core ML model
    ///   - configuration: configuration to be used when the model is loaded
    /// - Returns: An encoder that will lazily load its required resources when needed or requested
    public init(modelAt url: URL, configuration: MLModelConfiguration) {
        self.model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Prediction queue
    let queue = DispatchQueue(label: "encoder.predict")

    /// Encode image into latent sample
    ///
    /// - Parameters:
    ///   - image: Input image; must match the model's expected input shape (see `inputShape`)
    ///   - scaleFactor: Scalar multiplier applied to the sampled latent before it is returned
    ///   - random: Random source used to draw the Gaussian noise for latent sampling
    /// - Throws: `Error.sampleInputShapeNotCorrect` when `image` does not match `inputShape`,
    ///           or any error from image conversion or model prediction
    /// - Returns: The encoded latent sample as an MLShapedArray
    public func encode(
        _ image: CGImage,
        scaleFactor: Float32,
        random: inout RandomSource
    ) throws -> MLShapedArray<Float32> {
        let imageData = try image.plannerRGBShapedArray(minValue: -1.0, maxValue: 1.0)
        guard imageData.shape == inputShape else {
            // TODO: Consider auto resizing and cropping similar to how Vision or CoreML auto-generated Swift code can accomplish with `MLFeatureValue`
            throw Error.sampleInputShapeNotCorrect
        }
        let dict = [inputName: MLMultiArray(imageData)]
        let input = try MLDictionaryFeatureProvider(dictionary: dict)
        let result = try model.perform { model in
            try model.prediction(from: input)
        }
        // NOTE(review): the VAE encoder is expected to produce exactly one multi-array
        // output; a crash on these unwraps indicates an incompatible model.
        let outputName = result.featureNames.first!
        let outputValue = result.featureValue(for: outputName)!.multiArrayValue!
        let output = MLShapedArray<Float32>(outputValue)
        // DiagonalGaussianDistribution: the model emits means and log-variances
        // concatenated along the channel axis (shape [1, 2 * C, H, W]); split in half.
        // (Previously hard-coded to C == 4; derived from the output shape instead.)
        let latentChannels = output.shape[1] / 2
        let mean = output[0][0..<latentChannels]
        let logvar = MLShapedArray<Float32>(
            // Clamp log-variance to [-30, 20] for numerical stability, matching the
            // reference DiagonalGaussianDistribution implementation.
            scalars: output[0][latentChannels..<(2 * latentChannels)].scalars.map { min(max($0, -30), 20) },
            shape: mean.shape
        )
        let std = MLShapedArray<Float32>(
            scalars: logvar.scalars.map { exp(0.5 * $0) },
            shape: logvar.shape
        )
        // Sample latent ~ N(mean, std) elementwise (the reparameterized draw)
        let latent = MLShapedArray<Float32>(
            scalars: zip(mean.scalars, std.scalars).map {
                Float32(random.nextNormal(mean: Double($0), stdev: Double($1)))
            },
            shape: logvar.shape
        )
        // Reference pipeline scales the latent after encoding; re-add the
        // leading batch dimension dropped by the channel slice above.
        let latentScaled = MLShapedArray<Float32>(
            scalars: latent.scalars.map { $0 * scaleFactor },
            shape: [1] + latent.shape
        )
        return latentScaled
    }

    /// Feature description of the model's "z" input (the RGB image input)
    var inputDescription: MLFeatureDescription {
        try! model.perform { model in
            guard let zInputDescription = model.modelDescription.inputDescriptionsByName["z"] else {
                let modelVersion = model.modelDescription.metadata[MLModelMetadataKey.versionString] ?? "unknown version"
                fatalError(
                    """
                    The VAE encoder of this model (\(modelVersion)) is not compatible \
                    with this version of `ml-stable-diffusion`. Please, convert the VAE encoder again using the latest \
                    version of this package and following the instructions here: \
                    https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml
                    We'd appreciate if you could then submit the new VAE encoder as a PR to the repo from which this model \
                    was downloaded.
                    """)
            }
            return zInputDescription
        }
    }

    /// Name of the model's image input feature
    var inputName: String {
        inputDescription.name
    }

    /// The expected shape of the model's image input
    /// - NOTE(review): assumed from the guard in `encode` to be the full
    ///   [batch, channels, height, width] shape of `plannerRGBShapedArray`
    var inputShape: [Int] {
        inputDescription.multiArrayConstraint!.shape.map { $0.intValue }
    }
}