33import Foundation
44import MLX
55import MLXLLM
6- import MLXLMCommon
76import MLXNN
87import MLXOptimizers
98import Tokenizers
109import XCTest
1110
11+ @testable import MLXLMCommon
12+
1213/// See also ChatSessionIntegrationTests
1314public class ChatSessionTests : XCTestCase {
1415
@@ -162,36 +163,43 @@ public class ChatSessionTests: XCTestCase {
162163 // MARK: - KV Cache
163164
164165 func testCurrentCacheNilBeforeGeneration( ) async throws {
165- let session = ChatSession ( model ( ) )
166- let cache = await session. currentCache ( )
167- XCTAssertNil ( cache)
166+ let session = ChatSession ( model ( ) , generateParameters: generationParameters)
167+ await session. withCache { cache in
168+ XCTAssertNil ( cache)
169+ }
168170 }
169171
170172 func testCurrentCacheAfterGeneration( ) async throws {
171- let session = ChatSession ( model ( ) )
173+ let session = ChatSession ( model ( ) , generateParameters : generationParameters )
172174 _ = try await session. respond ( to: " hello " )
173- let cache = await session. currentCache ( )
174- XCTAssertNotNil ( cache)
175+ await session. withCache { cache in
176+ XCTAssertNotNil ( cache)
177+ }
175178 }
176179
177180 func testInitWithKVCache( ) async throws {
178181 // build a cache from an initial session
179- let ctx = model ( )
180- let initial = ChatSession ( ctx )
182+ let container = ModelContainer ( context : model ( ) )
183+ let initial = ChatSession ( container , generateParameters : generationParameters )
181184 _ = try await initial. respond ( to: " hello " )
182- guard let cache = await initial. currentCache ( ) else {
183- XCTFail ( " expected cache after generation " )
184- return
185- }
186185
187- // restore the cache into a new session and verify generation continues
188- let restored = ChatSession ( ctx, cache: cache)
189- let result = try await restored. respond ( to: " hello again " )
190- XCTAssertGreaterThan ( result. count, targetLength, result)
186+ try await initial. withCache { [ targetLength, generationParameters] cache in
187+ XCTAssertNotNil ( cache)
188+
189+ if let cache {
190+ // restore the cache into a new session and verify generation continues
191+ let restored = ChatSession (
192+ container,
193+ cache: cache. map { $0. copy ( ) } ,
194+ generateParameters: generationParameters)
195+ let result = try await restored. respond ( to: " hello again " )
196+ XCTAssertGreaterThan ( result. count, targetLength, result)
197+ }
198+ }
191199 }
192200
193201 func testSaveCacheThrowsBeforeGeneration( ) async throws {
194- let session = ChatSession ( model ( ) )
202+ let session = ChatSession ( model ( ) , generateParameters : generationParameters )
195203 let url = FileManager . default. temporaryDirectory
196204 . appendingPathComponent ( UUID ( ) . uuidString)
197205 . appendingPathExtension ( " safetensors " )
@@ -205,7 +213,7 @@ public class ChatSessionTests: XCTestCase {
205213
206214 func testSaveAndRestoreCache( ) async throws {
207215 let ctx = model ( )
208- let initial = ChatSession ( ctx)
216+ let initial = ChatSession ( ctx, generateParameters : generationParameters )
209217 _ = try await initial. respond ( to: " hello " )
210218
211219 let url = FileManager . default. temporaryDirectory
@@ -214,37 +222,46 @@ public class ChatSessionTests: XCTestCase {
214222 try await initial. saveCache ( to: url)
215223
216224 let ( loadedCache, _) = try loadPromptCache ( url: url)
217- let restored = ChatSession ( ctx, cache: loadedCache)
225+ let restored = ChatSession (
226+ ctx, cache: loadedCache, generateParameters: generationParameters)
218227 let result = try await restored. respond ( to: " hello again " )
219228 XCTAssertGreaterThan ( result. count, targetLength, result)
220229 }
221230
222231 func testCurrentCacheNilForHistorySessionBeforeGeneration( ) async throws {
223232 // .history state should behave like .empty: no cache until first generation
224233 let history : [ Chat . Message ] = [ . user( " hello " ) , . assistant( " hi " ) ]
225- let session = ChatSession ( model ( ) , history: history)
226- let cache = await session. currentCache ( )
227- XCTAssertNil ( cache)
234+ let session = ChatSession (
235+ model ( ) , history: history, generateParameters: generationParameters)
236+ await session. withCache { cache in
237+ XCTAssertNil ( cache)
238+ }
228239 }
229240
230241 func testCurrentCacheNonNilForHistorySessionAfterGeneration( ) async throws {
231242 // after generation from .history state, cache transitions to .kvcache
232243 let history : [ Chat . Message ] = [ . user( " hello " ) , . assistant( " hi " ) ]
233- let session = ChatSession ( model ( ) , history: history)
244+ let session = ChatSession (
245+ model ( ) ,
246+ history: history,
247+ generateParameters: generationParameters)
234248 _ = try await session. respond ( to: " hello again " )
235- let cache = await session. currentCache ( )
236- XCTAssertNotNil ( cache)
249+ await session. withCache { cache in
250+ XCTAssertNotNil ( cache)
251+ }
237252 }
238253
239254 func testCurrentCacheNilAfterClear( ) async throws {
240255 // clear() resets to .empty; currentCache() should return nil again
241- let session = ChatSession ( model ( ) )
256+ let session = ChatSession ( model ( ) , generateParameters : generationParameters )
242257 _ = try await session. respond ( to: " hello " )
243- let cacheBeforeClear = await session. currentCache ( )
244- XCTAssertNotNil ( cacheBeforeClear)
258+ await session. withCache { cache in
259+ XCTAssertNotNil ( cache)
260+ }
245261 await session. clear ( )
246- let cacheAfterClear = await session. currentCache ( )
247- XCTAssertNil ( cacheAfterClear)
262+ await session. withCache { cache in
263+ XCTAssertNil ( cache)
264+ }
248265 }
249266
250267 /// something that looks like a view model
0 commit comments