Skip to content

Commit 6512a39

Browse files
committed
Fix concurrency errors in MLTensor extension
1 parent c201da3 commit 6512a39

4 files changed

Lines changed: 111 additions & 66 deletions

File tree

Sources/WhisperKit/Core/Text/TokenSampler.swift

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import CoreML
66
import Foundation
77

88
public protocol TokenSampling {
9-
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult
9+
func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult
1010
func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult
1111
}
1212

@@ -39,7 +39,7 @@ open class GreedyTokenSampler: TokenSampling {
3939

4040
#if canImport(CoreML.MLState)
4141
@available(macOS 15, iOS 18, watchOS 11, visionOS 2, *)
42-
private func sampleWithMLTensor(logits: MLMultiArray) -> (token: Int, logprob: Float) {
42+
private func sampleWithMLTensor(logits: MLMultiArray) async -> (token: Int, logprob: Float) {
4343
// Use MLTensor operations if available for sampling
4444
// Reference: https://github.com/huggingface/swift-transformers/blob/preview/Sources/Generation/Decoders.swift
4545
var logitsTensor = MLTensor(MLShapedArray<FloatType>(logits)).cast(to: Float.self)
@@ -76,9 +76,11 @@ open class GreedyTokenSampler: TokenSampling {
7676
nextLogprobTensor = softmaxScores.gathering(atIndices: nextTokenTensor, alongAxis: -1).log()
7777
}
7878

79+
async let nextTokenArray = nextTokenTensor.asIntArray()
80+
async let nextLogprobArray = nextLogprobTensor.asFloatArray()
7981
return (
80-
token: nextTokenTensor.asIntArray()[0],
81-
logprob: nextLogprobTensor.asFloatArray()[0]
82+
token: await nextTokenArray[0],
83+
logprob: await nextLogprobArray[0]
8284
)
8385
}
8486
#endif
@@ -212,15 +214,15 @@ open class GreedyTokenSampler: TokenSampling {
212214
return (token: nextToken!, logprob: nextLogprob)
213215
}
214216

215-
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
217+
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
216218
var nextTokens = tokens
217219
var nextLogprobs = logProbs
218220
var completed = false
219221

220222
var result: (token: Int, logprob: Float)
221223
#if canImport(CoreML.MLState)
222224
if #available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *) {
223-
result = sampleWithMLTensor(logits: logits)
225+
result = await sampleWithMLTensor(logits: logits)
224226
} else {
225227
result = sampleWithBNNS(logits: logits)
226228
}
@@ -278,7 +280,7 @@ open class BeamSearchTokenSampler: TokenSampling {
278280
finishedSequences = []
279281
}
280282

281-
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult {
283+
public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) async -> SamplingResult {
282284
// TODO: Implement
283285
fatalError("Not implemented: \(#function)")
284286
}

Sources/WhisperKit/Core/TextDecoder.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
686686

687687
let samplingStartTime = Date()
688688

689-
let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
689+
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
690690

691691
nextToken = sampleResult.tokens.last!
692692
logProbs = sampleResult.logProbs
@@ -838,7 +838,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
838838

839839
let samplingStartTime = Date()
840840

841-
let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
841+
let sampleResult = await tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs)
842842

843843
nextToken = sampleResult.tokens.last!
844844
let nextTokenLogProb = sampleResult.logProbs.last!

Sources/WhisperKit/Utilities/Extensions+Public.swift

Lines changed: 26 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -160,69 +160,38 @@ public extension MLMultiArray {
160160
#if canImport(CoreML.MLState)
161161
@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
162162
public extension MLTensor {
163-
func asIntArray() -> [Int] {
164-
let semaphore = DispatchSemaphore(value: 0)
165-
var result: [Int] = []
166-
167-
Task(priority: .high) {
168-
result = await self.shapedArray(of: Int32.self).scalars.map { Int($0) }
169-
semaphore.signal()
170-
}
171-
172-
semaphore.wait()
173-
return result
163+
func asIntArray() async -> [Int] {
164+
await shapedArray(of: Int32.self).scalars.map { Int($0) }
174165
}
175166

176-
func asFloatArray() -> [Float] {
177-
let semaphore = DispatchSemaphore(value: 0)
178-
let tensorType = self.scalarType
179-
180-
var result: [Float] = []
181-
182-
Task(priority: .high) {
183-
switch tensorType {
184-
case is Float32.Type:
185-
result = await self.shapedArray(of: Float32.self).scalars.map { Float($0) }
186-
case is FloatType.Type:
187-
result = await self.shapedArray(of: FloatType.self).scalars.map { Float($0) }
188-
case is Float.Type:
189-
result = await self.shapedArray(of: Float.self).scalars.map { Float($0) }
190-
case is Int32.Type:
191-
result = await self.shapedArray(of: Int32.self).scalars.map { Float($0) }
192-
default:
193-
fatalError("Unsupported data type")
194-
}
195-
semaphore.signal()
167+
func asFloatArray() async -> [Float] {
168+
switch scalarType {
169+
case is Float32.Type:
170+
await shapedArray(of: Float32.self).scalars.map { Float($0) }
171+
case is FloatType.Type:
172+
await shapedArray(of: FloatType.self).scalars.map { Float($0) }
173+
case is Float.Type:
174+
await shapedArray(of: Float.self).scalars.map { Float($0) }
175+
case is Int32.Type:
176+
await shapedArray(of: Int32.self).scalars.map { Float($0) }
177+
default:
178+
fatalError("Unsupported data type")
196179
}
197-
198-
semaphore.wait()
199-
return result
200180
}
201181

202-
func asMLMultiArray() -> MLMultiArray {
203-
let semaphore = DispatchSemaphore(value: 0)
204-
let tensorType = self.scalarType
205-
206-
var result = try! MLMultiArray(shape: [1], dataType: .float16, initialValue: 0.0)
207-
208-
Task(priority: .high) {
209-
switch tensorType {
210-
case is Float32.Type:
211-
result = MLMultiArray(await self.shapedArray(of: Float32.self))
212-
case is FloatType.Type:
213-
result = MLMultiArray(await self.shapedArray(of: FloatType.self))
214-
case is Float.Type:
215-
result = MLMultiArray(await self.shapedArray(of: Float.self))
216-
case is Int32.Type:
217-
result = MLMultiArray(await self.shapedArray(of: Int32.self))
218-
default:
219-
fatalError("Unsupported data type")
220-
}
221-
semaphore.signal()
182+
func asMLMultiArray() async -> MLMultiArray {
183+
switch scalarType {
184+
case is Float32.Type:
185+
MLMultiArray(await shapedArray(of: Float32.self))
186+
case is FloatType.Type:
187+
MLMultiArray(await shapedArray(of: FloatType.self))
188+
case is Float.Type:
189+
MLMultiArray(await shapedArray(of: Float.self))
190+
case is Int32.Type:
191+
MLMultiArray(await shapedArray(of: Int32.self))
192+
default:
193+
fatalError("Unsupported data type")
222194
}
223-
224-
semaphore.wait()
225-
return result
226195
}
227196
}
228197
#endif
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// For licensing see accompanying LICENSE.md file.
2+
// Copyright © 2024 Argmax, Inc. All rights reserved.
3+
4+
#if canImport(CoreML.MLState)
5+
import CoreML
6+
@testable import WhisperKit
7+
import XCTest
8+
9+
@available(macOS 15.0, iOS 18.0, watchOS 11.0, visionOS 2.0, *)
10+
final class MLTensorExtensionsTests: XCTestCase {
11+
func testAsIntArrayReturnsExpectedScalars() async {
12+
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [1, -2, 42], shape: [3]))
13+
14+
let result = await tensor.asIntArray()
15+
16+
XCTAssertEqual(result, [1, -2, 42])
17+
}
18+
19+
func testAsFloatArraySupportsFloat32Tensor() async {
20+
let tensor = MLTensor(MLShapedArray<Float32>(scalars: [0.25, -1.5, 2.0], shape: [3]))
21+
22+
let result = await tensor.asFloatArray()
23+
24+
assertEqual(result, [0.25, -1.5, 2.0], accuracy: 0.0001)
25+
}
26+
27+
func testAsFloatArraySupportsFloatTypeTensor() async {
28+
let expected = [FloatType(0.125), FloatType(-0.75), FloatType(3.5)]
29+
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))
30+
31+
let result = await tensor.asFloatArray()
32+
33+
assertEqual(result, expected.map(Float.init), accuracy: 0.0001)
34+
}
35+
36+
func testAsFloatArraySupportsInt32Tensor() async {
37+
let tensor = MLTensor(MLShapedArray<Int32>(scalars: [-3, 0, 7], shape: [3]))
38+
39+
let result = await tensor.asFloatArray()
40+
41+
assertEqual(result, [-3, 0, 7], accuracy: 0.0001)
42+
}
43+
44+
func testAsMLMultiArrayRoundTripsFloatTypeTensor() async {
45+
let expected = [FloatType(1.25), FloatType(-0.5), FloatType(3.75)]
46+
let tensor = MLTensor(MLShapedArray<FloatType>(scalars: expected, shape: [3]))
47+
48+
let result = await tensor.asMLMultiArray()
49+
let shapedArray = MLShapedArray<FloatType>(result)
50+
51+
XCTAssertEqual(result.shape, [3])
52+
XCTAssertEqual(shapedArray.scalars.count, expected.count)
53+
assertEqual(shapedArray.scalars.map(Float.init), expected.map(Float.init), accuracy: 0.0001)
54+
}
55+
56+
func testAsMLMultiArrayRoundTripsInt32Tensor() async {
57+
let expected: [Int32] = [-9, 4, 12]
58+
let tensor = MLTensor(MLShapedArray<Int32>(scalars: expected, shape: [3]))
59+
60+
let result = await tensor.asMLMultiArray()
61+
let shapedArray = MLShapedArray<Int32>(result)
62+
63+
XCTAssertEqual(result.shape, [3])
64+
XCTAssertEqual(shapedArray.scalars, expected)
65+
}
66+
67+
private func assertEqual(_ lhs: [Float], _ rhs: [Float], accuracy: Float) {
68+
XCTAssertEqual(lhs.count, rhs.count)
69+
for (actual, expected) in zip(lhs, rhs) {
70+
XCTAssertEqual(actual, expected, accuracy: accuracy)
71+
}
72+
}
73+
}
74+
#endif

0 commit comments

Comments
 (0)