Skip to content

Commit 25b00d4

Browse files
authored
switch to swift 6 -- prevent concurrency issues, fix concurrency issues (#165)
* switch to swift 6 -- prevent concurrency issues, fix concurrency issues
  - some concurrency issues snuck in for test support
  - fix and prevent more from coming in by switching Package to swift 6
  - depends on ml-explore/mlx-swift#379
* pick up mlx-swift with fix for save_safetensors concurrency issue
1 parent 2a296f1 commit 25b00d4

4 files changed

Lines changed: 71 additions & 73 deletions

File tree

Libraries/MLXLMCommon/ChatSession.swift

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -499,21 +499,19 @@ public final class ChatSession {
499499
await cache.read { _ in }
500500
}
501501

502-
/// Returns the current KV cache, if one has been built.
502+
/// Visit the current cache value, if realized as a `[KVCache]`.
503503
///
504-
/// Returns `nil` if no generation has occurred yet (cache is still empty) or if the
505-
/// session is in history-rehydration mode and generation has not started.
506-
///
507-
/// The returned array holds references to the live cache objects — do not use them
508-
/// concurrently with an active ``respond(to:role:images:videos:)`` or
509-
/// ``streamResponse(_:)`` call on the same session. To persist the cache
510-
/// across process launches, use ``saveCache(to:)`` instead.
511-
public func currentCache() async -> [KVCache]? {
512-
await cache.read { cache in
513-
if case .kvcache(let array) = cache {
514-
return array
504+
/// This method is meant for test support.
505+
func withCache<R: Sendable>(_ body: @Sendable ([KVCache]?) async throws -> R) async rethrows
506+
-> R?
507+
{
508+
try await cache.read { cache in
509+
switch cache {
510+
case .kvcache(let cache):
511+
return try await body(cache)
512+
default:
513+
return try await body(nil)
515514
}
516-
return nil
517515
}
518516
}
519517

@@ -526,10 +524,14 @@ public final class ChatSession {
526524
/// - Throws: ``ChatSessionError/noCacheAvailable`` if no generation has occurred yet,
527525
/// or any error thrown by the underlying file write
528526
public func saveCache(to url: URL) async throws {
529-
guard let kvCache = await currentCache() else {
530-
throw ChatSessionError.noCacheAvailable
527+
try await cache.read { cache in
528+
switch cache {
529+
case .kvcache(let cache):
530+
try savePromptCache(url: url, cache: cache)
531+
default:
532+
throw ChatSessionError.noCacheAvailable
533+
}
531534
}
532-
try savePromptCache(url: url, cache: kvCache)
533535
}
534536
}
535537

Package.swift

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// swift-tools-version: 5.12
1+
// swift-tools-version: 6.1
22
// The swift-tools-version declares the minimum version of Swift required to build this package.
33

44
import PackageDescription
@@ -26,7 +26,7 @@ let package = Package(
2626
targets: ["MLXEmbedders"]),
2727
],
2828
dependencies: [
29-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")),
29+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")),
3030
.package(
3131
url: "https://github.com/huggingface/swift-transformers",
3232
.upToNextMinor(from: "1.2.0")
@@ -45,9 +45,6 @@ let package = Package(
4545
path: "Libraries/MLXLLM",
4646
exclude: [
4747
"README.md"
48-
],
49-
swiftSettings: [
50-
.enableExperimentalFeature("StrictConcurrency")
5148
]
5249
),
5350
.target(
@@ -62,9 +59,6 @@ let package = Package(
6259
path: "Libraries/MLXVLM",
6360
exclude: [
6461
"README.md"
65-
],
66-
swiftSettings: [
67-
.enableExperimentalFeature("StrictConcurrency")
6862
]
6963
),
7064
.target(
@@ -78,9 +72,6 @@ let package = Package(
7872
path: "Libraries/MLXLMCommon",
7973
exclude: [
8074
"README.md"
81-
],
82-
swiftSettings: [
83-
.enableExperimentalFeature("StrictConcurrency")
8475
]
8576
),
8677
.target(
@@ -94,9 +85,6 @@ let package = Package(
9485
path: "Libraries/MLXEmbedders",
9586
exclude: [
9687
"README.md"
97-
],
98-
swiftSettings: [
99-
.enableExperimentalFeature("StrictConcurrency")
10088
]
10189
),
10290
.testTarget(
@@ -115,10 +103,7 @@ let package = Package(
115103
exclude: [
116104
"README.md"
117105
],
118-
resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")],
119-
swiftSettings: [
120-
.enableExperimentalFeature("StrictConcurrency")
121-
]
106+
resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")]
122107
),
123108
.testTarget(
124109
name: "MLXLMIntegrationTests",
@@ -135,9 +120,6 @@ let package = Package(
135120
path: "Tests/MLXLMIntegrationTests",
136121
exclude: [
137122
"README.md"
138-
],
139-
swiftSettings: [
140-
.enableExperimentalFeature("StrictConcurrency")
141123
]
142124
),
143125
.testTarget(
@@ -147,10 +129,7 @@ let package = Package(
147129
"MLXVLM",
148130
"MLXLMCommon",
149131
],
150-
path: "Tests/Benchmarks",
151-
swiftSettings: [
152-
.enableExperimentalFeature("StrictConcurrency")
153-
]
132+
path: "Tests/Benchmarks"
154133
),
155134
]
156135
)

Tests/MLXLMTests/ChatSessionTests.swift

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import Foundation
44
import MLX
55
import MLXLLM
6-
import MLXLMCommon
76
import MLXNN
87
import MLXOptimizers
98
import Tokenizers
109
import XCTest
1110

11+
@testable import MLXLMCommon
12+
1213
/// See also ChatSessionIntegrationTests
1314
public class ChatSessionTests: XCTestCase {
1415

@@ -162,36 +163,43 @@ public class ChatSessionTests: XCTestCase {
162163
// MARK: - KV Cache
163164

164165
func testCurrentCacheNilBeforeGeneration() async throws {
165-
let session = ChatSession(model())
166-
let cache = await session.currentCache()
167-
XCTAssertNil(cache)
166+
let session = ChatSession(model(), generateParameters: generationParameters)
167+
await session.withCache { cache in
168+
XCTAssertNil(cache)
169+
}
168170
}
169171

170172
func testCurrentCacheAfterGeneration() async throws {
171-
let session = ChatSession(model())
173+
let session = ChatSession(model(), generateParameters: generationParameters)
172174
_ = try await session.respond(to: "hello")
173-
let cache = await session.currentCache()
174-
XCTAssertNotNil(cache)
175+
await session.withCache { cache in
176+
XCTAssertNotNil(cache)
177+
}
175178
}
176179

177180
func testInitWithKVCache() async throws {
178181
// build a cache from an initial session
179-
let ctx = model()
180-
let initial = ChatSession(ctx)
182+
let container = ModelContainer(context: model())
183+
let initial = ChatSession(container, generateParameters: generationParameters)
181184
_ = try await initial.respond(to: "hello")
182-
guard let cache = await initial.currentCache() else {
183-
XCTFail("expected cache after generation")
184-
return
185-
}
186185

187-
// restore the cache into a new session and verify generation continues
188-
let restored = ChatSession(ctx, cache: cache)
189-
let result = try await restored.respond(to: "hello again")
190-
XCTAssertGreaterThan(result.count, targetLength, result)
186+
try await initial.withCache { [targetLength, generationParameters] cache in
187+
XCTAssertNotNil(cache)
188+
189+
if let cache {
190+
// restore the cache into a new session and verify generation continues
191+
let restored = ChatSession(
192+
container,
193+
cache: cache.map { $0.copy() },
194+
generateParameters: generationParameters)
195+
let result = try await restored.respond(to: "hello again")
196+
XCTAssertGreaterThan(result.count, targetLength, result)
197+
}
198+
}
191199
}
192200

193201
func testSaveCacheThrowsBeforeGeneration() async throws {
194-
let session = ChatSession(model())
202+
let session = ChatSession(model(), generateParameters: generationParameters)
195203
let url = FileManager.default.temporaryDirectory
196204
.appendingPathComponent(UUID().uuidString)
197205
.appendingPathExtension("safetensors")
@@ -205,7 +213,7 @@ public class ChatSessionTests: XCTestCase {
205213

206214
func testSaveAndRestoreCache() async throws {
207215
let ctx = model()
208-
let initial = ChatSession(ctx)
216+
let initial = ChatSession(ctx, generateParameters: generationParameters)
209217
_ = try await initial.respond(to: "hello")
210218

211219
let url = FileManager.default.temporaryDirectory
@@ -214,37 +222,46 @@ public class ChatSessionTests: XCTestCase {
214222
try await initial.saveCache(to: url)
215223

216224
let (loadedCache, _) = try loadPromptCache(url: url)
217-
let restored = ChatSession(ctx, cache: loadedCache)
225+
let restored = ChatSession(
226+
ctx, cache: loadedCache, generateParameters: generationParameters)
218227
let result = try await restored.respond(to: "hello again")
219228
XCTAssertGreaterThan(result.count, targetLength, result)
220229
}
221230

222231
func testCurrentCacheNilForHistorySessionBeforeGeneration() async throws {
223232
// .history state should behave like .empty: no cache until first generation
224233
let history: [Chat.Message] = [.user("hello"), .assistant("hi")]
225-
let session = ChatSession(model(), history: history)
226-
let cache = await session.currentCache()
227-
XCTAssertNil(cache)
234+
let session = ChatSession(
235+
model(), history: history, generateParameters: generationParameters)
236+
await session.withCache { cache in
237+
XCTAssertNil(cache)
238+
}
228239
}
229240

230241
func testCurrentCacheNonNilForHistorySessionAfterGeneration() async throws {
231242
// after generation from .history state, cache transitions to .kvcache
232243
let history: [Chat.Message] = [.user("hello"), .assistant("hi")]
233-
let session = ChatSession(model(), history: history)
244+
let session = ChatSession(
245+
model(),
246+
history: history,
247+
generateParameters: generationParameters)
234248
_ = try await session.respond(to: "hello again")
235-
let cache = await session.currentCache()
236-
XCTAssertNotNil(cache)
249+
await session.withCache { cache in
250+
XCTAssertNotNil(cache)
251+
}
237252
}
238253

239254
func testCurrentCacheNilAfterClear() async throws {
240255
// clear() resets to .empty; withCache should observe a nil cache again
241-
let session = ChatSession(model())
256+
let session = ChatSession(model(), generateParameters: generationParameters)
242257
_ = try await session.respond(to: "hello")
243-
let cacheBeforeClear = await session.currentCache()
244-
XCTAssertNotNil(cacheBeforeClear)
258+
await session.withCache { cache in
259+
XCTAssertNotNil(cache)
260+
}
245261
await session.clear()
246-
let cacheAfterClear = await session.currentCache()
247-
XCTAssertNil(cacheAfterClear)
262+
await session.withCache { cache in
263+
XCTAssertNil(cache)
264+
}
248265
}
249266

250267
/// something that looks like a view model

Tests/MLXLMTests/KVCacheTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import MLX
33
import MLXLMCommon
44
import Testing
55

6-
private let cacheCreators: [() -> any KVCache] = [
6+
private let cacheCreators: [@Sendable () -> any KVCache] = [
77
{ KVCacheSimple() },
88
{ RotatingKVCache(maxSize: 32) },
99
{ QuantizedKVCache() },

0 commit comments

Comments
 (0)