Skip to content

Commit 223f04d

Browse files
authored
[AI] Lazily initialize sessions in HybridModelSession (#16119)
1 parent 59d8956 commit 223f04d

5 files changed

Lines changed: 456 additions & 91 deletions

File tree

FirebaseAI/Sources/Types/Internal/HybridModelSession.swift

Lines changed: 142 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,38 @@
1313
// limitations under the License.
1414

1515
#if compiler(>=6.2.3)
16+
private import FirebaseCoreInternal
17+
1618
final class HybridModelSession: _ModelSession {
17-
private let primary: any _ModelSession
18-
private let secondary: any _ModelSession
19+
private let primaryModel: any LanguageModel
20+
private let secondaryModel: any LanguageModel
21+
private let tools: [any ToolRepresentable]?
22+
private let instructions: String?
23+
24+
typealias SessionState = (primary: (any _ModelSession)?, secondary: (any _ModelSession)?)
25+
private let lock: UnfairLock<SessionState>
26+
27+
init(primaryModel: any LanguageModel, secondaryModel: any LanguageModel,
28+
tools: [any ToolRepresentable]?, instructions: String?) {
29+
self.primaryModel = primaryModel
30+
self.secondaryModel = secondaryModel
31+
self.tools = tools
32+
self.instructions = instructions
33+
lock = UnfairLock((primary: nil, secondary: nil))
34+
}
1935

20-
init(primary: any _ModelSession, secondary: any _ModelSession) {
21-
self.primary = primary
22-
self.secondary = secondary
36+
enum SessionModel {
37+
case primaryModel
38+
case secondaryModel
2339
}
2440

2541
/// Returns `true` if the session has history (i.e., it has already had one or more chat turns).
2642
///
2743
/// > Important: This property is for **internal use only** and may change at any time.
2844
var _hasHistory: Bool {
29-
return primary._hasHistory || secondary._hasHistory
45+
return lock.withLock { state in
46+
state.primary?._hasHistory == true || state.secondary?._hasHistory == true
47+
}
3048
}
3149

3250
/// Sends a prompt to the model and returns a ``_ModelSessionResponse``.
@@ -45,8 +63,12 @@
4563
return try await TaskLocals.$isHybridRequest.withValue(true) {
4664
// If the secondary session contains history then a previous fallback occurred.
4765
// Stick with the secondary session to maintain conversation consistency.
48-
if secondary._hasHistory {
49-
return try await secondary._respond(
66+
let useSecondary = lock.withLock { state in
67+
state.secondary?._hasHistory == true
68+
}
69+
if useSecondary {
70+
let secondarySession = try getSession(for: .secondaryModel)
71+
return try await secondarySession._respond(
5072
to: prompt,
5173
schema: schema,
5274
includeSchemaInPrompt: includeSchemaInPrompt,
@@ -56,20 +78,29 @@
5678

5779
do {
5880
// First try the primary session.
59-
return try await primary._respond(
81+
let primarySession = try getSession(for: .primaryModel)
82+
return try await primarySession._respond(
6083
to: prompt,
6184
schema: schema,
6285
includeSchemaInPrompt: includeSchemaInPrompt,
6386
options: options
6487
)
6588
} catch {
89+
if Task.isCancelled || error is CancellationError {
90+
throw error
91+
}
92+
6693
// Do not fall back to the secondary session if the primary session contains history.
67-
if primary._hasHistory {
94+
let primaryHasHistory = lock.withLock { state in
95+
state.primary?._hasHistory == true
96+
}
97+
if primaryHasHistory {
6898
throw error
6999
}
70100

71101
// Fall back to the secondary session if the primary session fails or is unavailable.
72-
return try await secondary._respond(
102+
let secondarySession = try getSession(for: .secondaryModel)
103+
return try await secondarySession._respond(
73104
to: prompt,
74105
schema: schema,
75106
includeSchemaInPrompt: includeSchemaInPrompt,
@@ -95,49 +126,95 @@
95126
return TaskLocals.$isHybridRequest.withValue(true) {
96127
// If the secondary session contains history then a previous fallback occurred.
97128
// Stick with the secondary session to maintain conversation consistency.
98-
if secondary._hasHistory {
99-
return secondary._streamResponse(
100-
to: prompt,
101-
schema: schema,
102-
includeSchemaInPrompt: includeSchemaInPrompt,
103-
options: options
104-
)
129+
let useSecondary = lock.withLock { state in
130+
state.secondary?._hasHistory == true
105131
}
106-
107-
return AsyncThrowingStream { continuation in
108-
let task = Task {
109-
// First try the primary session.
110-
let stream = primary._streamResponse(
132+
if useSecondary {
133+
do {
134+
let secondarySession = try getSession(for: .secondaryModel)
135+
return secondarySession._streamResponse(
111136
to: prompt,
112137
schema: schema,
113138
includeSchemaInPrompt: includeSchemaInPrompt,
114139
options: options
115140
)
141+
} catch {
142+
return AsyncThrowingStream { continuation in
143+
continuation.finish(throwing: error)
144+
}
145+
}
146+
}
116147

117-
var didYield = false
148+
return AsyncThrowingStream { continuation in
149+
let task = Task {
118150
do {
119-
for try await snapshot in stream {
120-
didYield = true
121-
continuation.yield(snapshot)
151+
// First try the primary session.
152+
let primarySession = try self.getSession(for: .primaryModel)
153+
let stream = primarySession._streamResponse(
154+
to: prompt,
155+
schema: schema,
156+
includeSchemaInPrompt: includeSchemaInPrompt,
157+
options: options
158+
)
159+
160+
var didYield = false
161+
do {
162+
for try await snapshot in stream {
163+
didYield = true
164+
continuation.yield(snapshot)
165+
}
166+
continuation.finish()
167+
} catch {
168+
if Task.isCancelled || error is CancellationError {
169+
continuation.finish(throwing: error)
170+
return
171+
}
172+
173+
// Do not fall back to the secondary session if the primary session contains history or has
174+
// already yielded data.
175+
let primaryHasHistory = self.lock.withLock { state in
176+
state.primary?._hasHistory == true
177+
}
178+
if didYield || primaryHasHistory {
179+
continuation.finish(throwing: error)
180+
return
181+
}
182+
183+
// Fall back to the secondary session if the primary session fails or is unavailable.
184+
let secondarySession = try self.getSession(for: .secondaryModel)
185+
let stream = secondarySession._streamResponse(
186+
to: prompt,
187+
schema: schema,
188+
includeSchemaInPrompt: includeSchemaInPrompt,
189+
options: options
190+
)
191+
192+
do {
193+
for try await snapshot in stream {
194+
continuation.yield(snapshot)
195+
}
196+
continuation.finish()
197+
} catch {
198+
continuation.finish(throwing: error)
199+
}
122200
}
123-
continuation.finish()
124201
} catch {
125-
// Do not fallback to second session if the primary session contains history or has
126-
// already yielded data.
127-
if didYield || primary._hasHistory {
202+
if Task.isCancelled || error is CancellationError {
128203
continuation.finish(throwing: error)
129204
return
130205
}
131206

207+
// Failure to create primary session.
132208
// Fall back to the secondary session if the primary session fails or is unavailable.
133-
let stream = secondary._streamResponse(
134-
to: prompt,
135-
schema: schema,
136-
includeSchemaInPrompt: includeSchemaInPrompt,
137-
options: options
138-
)
139-
140209
do {
210+
let secondarySession = try self.getSession(for: .secondaryModel)
211+
let stream = secondarySession._streamResponse(
212+
to: prompt,
213+
schema: schema,
214+
includeSchemaInPrompt: includeSchemaInPrompt,
215+
options: options
216+
)
217+
141218
for try await snapshot in stream {
142219
continuation.yield(snapshot)
143220
}
@@ -151,5 +228,32 @@
151228
}
152229
}
153230
}
231+
232+
private func getSession(for model: SessionModel) throws -> any _ModelSession {
233+
let languageModel = (model == .primaryModel) ? primaryModel : secondaryModel
234+
235+
// 1. Check if it exists under lock.
236+
let existing = lock.withLock { state in
237+
(model == .primaryModel) ? state.primary : state.secondary
238+
}
239+
if let existing {
240+
return existing
241+
}
242+
243+
// 2. Create it outside the lock.
244+
let session = try languageModel._startSession(tools: tools, instructions: instructions)
245+
246+
// 3. Try to store it under lock.
247+
return lock.withLock { state in
248+
if model == .primaryModel {
249+
if let existing = state.primary { return existing }
250+
state.primary = session
251+
} else {
252+
if let existing = state.secondary { return existing }
253+
state.secondary = session
254+
}
255+
return session
256+
}
257+
}
154258
}
155259
#endif // compiler(>=6.2.3)

FirebaseAI/Sources/Types/Public/HybridModel.swift

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,41 +46,12 @@
4646
/// > Important: This method is for **internal use only** and may change at any time.
4747
public func _startSession(tools: [any ToolRepresentable]?,
4848
instructions: String?) throws -> any _ModelSession {
49-
var primarySession: (any _ModelSession)?
50-
var primaryError: Error?
51-
var secondarySession: (any _ModelSession)?
52-
var secondaryError: Error?
53-
54-
do {
55-
primarySession = try primary._startSession(tools: tools, instructions: instructions)
56-
} catch {
57-
primaryError = error
58-
}
59-
60-
do {
61-
secondarySession = try secondary._startSession(tools: tools, instructions: instructions)
62-
} catch {
63-
secondaryError = error
64-
}
65-
66-
if let primarySession, let secondarySession {
67-
return HybridModelSession(primary: primarySession, secondary: secondarySession)
68-
} else if let primarySession {
69-
return primarySession
70-
} else if let secondarySession {
71-
return secondarySession
72-
} else {
73-
let context = GenerativeModelSession.GenerationError.Context(
74-
debugDescription: """
75-
Failed to start session for model "\(primary._modelName)" with error: \
76-
\(primaryError?.localizedDescription ?? "unknown error");
77-
Failed to start session for model "\(secondary._modelName)" with error: \
78-
\(secondaryError?.localizedDescription ?? "unknown error")
79-
"""
80-
)
81-
82-
throw GenerativeModelSession.GenerationError.assetsUnavailable(context)
83-
}
49+
return HybridModelSession(
50+
primaryModel: primary,
51+
secondaryModel: secondary,
52+
tools: tools,
53+
instructions: instructions
54+
)
8455
}
8556
}
8657
#endif // compiler(>=6.2.3)

FirebaseAI/Tests/TestApp/Tests/Integration/GenerativeModelSessionHybridTests.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,41 @@
208208
#expect(response.rawResponse.modelVersion == validModel._modelName)
209209
}
210210

211+
@Test(arguments: [InstanceConfig.vertexAI_v1beta_global])
212+
@available(iOS 26.0, macOS 26.0, visionOS 26.0, *)
213+
@available(tvOS, unavailable)
214+
@available(watchOS, unavailable)
215+
func streamResponseText_fallbackOnFoundationModelsError(_ config: InstanceConfig) async throws {
216+
let firebaseAI = FirebaseAI.componentInstance(config)
217+
let systemModel = FirebaseAI.SystemLanguageModel.default
218+
let geminiModel = firebaseAI.geminiModel(name: ModelNames.gemini2_5_FlashLite)
219+
let session = firebaseAI.generativeModelSession(
220+
model: HybridModel(primary: systemModel, secondary: geminiModel)
221+
)
222+
let prompt = "In one sentence, why is the sky blue?"
223+
224+
let stream = session.streamResponse(to: prompt)
225+
226+
var receivedTexts = [String]()
227+
var isComplete = false
228+
for try await snapshot in stream {
229+
let partial = snapshot.content
230+
receivedTexts.append(partial)
231+
isComplete = snapshot.rawContent.isComplete
232+
}
233+
#expect(isComplete)
234+
235+
let response = try await stream.collect()
236+
let content = response.content
237+
#expect(!content.isEmpty)
238+
239+
if await foundationModelsIsAvailable() {
240+
#expect(response.rawResponse.modelVersion == systemModel._modelName)
241+
} else {
242+
#expect(response.rawResponse.modelVersion == geminiModel._modelName)
243+
}
244+
}
245+
211246
/// Returns `true` if `FoundationModels.SystemLanguageModel` is available.
212247
///
213248
/// This is a workaround for `SystemLanguageModel.isAvailable`, which returns `true` if *any*

0 commit comments

Comments
 (0)