|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #if compiler(>=6.2.3) |
| 16 | + private import FirebaseCoreInternal |
| 17 | + |
16 | 18 | final class HybridModelSession: _ModelSession { |
17 | | - private let primary: any _ModelSession |
18 | | - private let secondary: any _ModelSession |
| 19 | + private let primaryModel: any LanguageModel |
| 20 | + private let secondaryModel: any LanguageModel |
| 21 | + private let tools: [any ToolRepresentable]? |
| 22 | + private let instructions: String? |
| 23 | + |
| 24 | + typealias SessionState = (primary: (any _ModelSession)?, secondary: (any _ModelSession)?) |
| 25 | + private let lock: UnfairLock<SessionState> |
| 26 | + |
| 27 | + init(primaryModel: any LanguageModel, secondaryModel: any LanguageModel, |
| 28 | + tools: [any ToolRepresentable]?, instructions: String?) { |
| 29 | + self.primaryModel = primaryModel |
| 30 | + self.secondaryModel = secondaryModel |
| 31 | + self.tools = tools |
| 32 | + self.instructions = instructions |
| 33 | + lock = UnfairLock((primary: nil, secondary: nil)) |
| 34 | + } |
19 | 35 |
|
20 | | - init(primary: any _ModelSession, secondary: any _ModelSession) { |
21 | | - self.primary = primary |
22 | | - self.secondary = secondary |
| 36 | + enum SessionModel { |
| 37 | + case primaryModel |
| 38 | + case secondaryModel |
23 | 39 | } |
24 | 40 |
|
25 | 41 | /// Returns `true` if the session has history (i.e., it has already had one or more chat turns). |
26 | 42 | /// |
27 | 43 | /// > Important: This property is for **internal use only** and may change at any time. |
28 | 44 | var _hasHistory: Bool { |
29 | | - return primary._hasHistory || secondary._hasHistory |
| 45 | + return lock.withLock { state in |
| 46 | + state.primary?._hasHistory == true || state.secondary?._hasHistory == true |
| 47 | + } |
30 | 48 | } |
31 | 49 |
|
32 | 50 | /// Sends a prompt to the model and returns a ``_ModelSessionResponse``. |
|
45 | 63 | return try await TaskLocals.$isHybridRequest.withValue(true) { |
46 | 64 | // If the secondary session contains history then a previous fallback occurred. |
47 | 65 | // Stick with the secondary session to maintain conversation consistency. |
48 | | - if secondary._hasHistory { |
49 | | - return try await secondary._respond( |
| 66 | + let useSecondary = lock.withLock { state in |
| 67 | + state.secondary?._hasHistory == true |
| 68 | + } |
| 69 | + if useSecondary { |
| 70 | + let secondarySession = try getSession(for: .secondaryModel) |
| 71 | + return try await secondarySession._respond( |
50 | 72 | to: prompt, |
51 | 73 | schema: schema, |
52 | 74 | includeSchemaInPrompt: includeSchemaInPrompt, |
|
56 | 78 |
|
57 | 79 | do { |
58 | 80 | // First try the primary session. |
59 | | - return try await primary._respond( |
| 81 | + let primarySession = try getSession(for: .primaryModel) |
| 82 | + return try await primarySession._respond( |
60 | 83 | to: prompt, |
61 | 84 | schema: schema, |
62 | 85 | includeSchemaInPrompt: includeSchemaInPrompt, |
63 | 86 | options: options |
64 | 87 | ) |
65 | 88 | } catch { |
| 89 | + if Task.isCancelled || error is CancellationError { |
| 90 | + throw error |
| 91 | + } |
| 92 | + |
66 | 93 | // Do not fallback to second session if the primary session contains history. |
67 | | - if primary._hasHistory { |
| 94 | + let primaryHasHistory = lock.withLock { state in |
| 95 | + state.primary?._hasHistory == true |
| 96 | + } |
| 97 | + if primaryHasHistory { |
68 | 98 | throw error |
69 | 99 | } |
70 | 100 |
|
71 | 101 | // Fallback to the second session if the first fails or is unavailable. |
72 | | - return try await secondary._respond( |
| 102 | + let secondarySession = try getSession(for: .secondaryModel) |
| 103 | + return try await secondarySession._respond( |
73 | 104 | to: prompt, |
74 | 105 | schema: schema, |
75 | 106 | includeSchemaInPrompt: includeSchemaInPrompt, |
|
95 | 126 | return TaskLocals.$isHybridRequest.withValue(true) { |
96 | 127 | // If the secondary session contains history then a previous fallback occurred. |
97 | 128 | // Stick with the secondary session to maintain conversation consistency. |
98 | | - if secondary._hasHistory { |
99 | | - return secondary._streamResponse( |
100 | | - to: prompt, |
101 | | - schema: schema, |
102 | | - includeSchemaInPrompt: includeSchemaInPrompt, |
103 | | - options: options |
104 | | - ) |
| 129 | + let useSecondary = lock.withLock { state in |
| 130 | + state.secondary?._hasHistory == true |
105 | 131 | } |
106 | | - |
107 | | - return AsyncThrowingStream { continuation in |
108 | | - let task = Task { |
109 | | - // First try the primary session. |
110 | | - let stream = primary._streamResponse( |
| 132 | + if useSecondary { |
| 133 | + do { |
| 134 | + let secondarySession = try getSession(for: .secondaryModel) |
| 135 | + return secondarySession._streamResponse( |
111 | 136 | to: prompt, |
112 | 137 | schema: schema, |
113 | 138 | includeSchemaInPrompt: includeSchemaInPrompt, |
114 | 139 | options: options |
115 | 140 | ) |
| 141 | + } catch { |
| 142 | + return AsyncThrowingStream { continuation in |
| 143 | + continuation.finish(throwing: error) |
| 144 | + } |
| 145 | + } |
| 146 | + } |
116 | 147 |
|
117 | | - var didYield = false |
| 148 | + return AsyncThrowingStream { continuation in |
| 149 | + let task = Task { |
118 | 150 | do { |
119 | | - for try await snapshot in stream { |
120 | | - didYield = true |
121 | | - continuation.yield(snapshot) |
| 151 | + // First try the primary session. |
| 152 | + let primarySession = try self.getSession(for: .primaryModel) |
| 153 | + let stream = primarySession._streamResponse( |
| 154 | + to: prompt, |
| 155 | + schema: schema, |
| 156 | + includeSchemaInPrompt: includeSchemaInPrompt, |
| 157 | + options: options |
| 158 | + ) |
| 159 | + |
| 160 | + var didYield = false |
| 161 | + do { |
| 162 | + for try await snapshot in stream { |
| 163 | + didYield = true |
| 164 | + continuation.yield(snapshot) |
| 165 | + } |
| 166 | + continuation.finish() |
| 167 | + } catch { |
| 168 | + if Task.isCancelled || error is CancellationError { |
| 169 | + continuation.finish(throwing: error) |
| 170 | + return |
| 171 | + } |
| 172 | + |
| 173 | + // Do not fallback to second session if the primary session contains history or has |
| 174 | + // already yielded data. |
| 175 | + let primaryHasHistory = self.lock.withLock { state in |
| 176 | + state.primary?._hasHistory == true |
| 177 | + } |
| 178 | + if didYield || primaryHasHistory { |
| 179 | + continuation.finish(throwing: error) |
| 180 | + return |
| 181 | + } |
| 182 | + |
| 183 | + // Fallback to the second session if the first fails or is unavailable. |
| 184 | + let secondarySession = try self.getSession(for: .secondaryModel) |
| 185 | + let stream = secondarySession._streamResponse( |
| 186 | + to: prompt, |
| 187 | + schema: schema, |
| 188 | + includeSchemaInPrompt: includeSchemaInPrompt, |
| 189 | + options: options |
| 190 | + ) |
| 191 | + |
| 192 | + do { |
| 193 | + for try await snapshot in stream { |
| 194 | + continuation.yield(snapshot) |
| 195 | + } |
| 196 | + continuation.finish() |
| 197 | + } catch { |
| 198 | + continuation.finish(throwing: error) |
| 199 | + } |
122 | 200 | } |
123 | | - continuation.finish() |
124 | 201 | } catch { |
125 | | - // Do not fallback to second session if the primary session contains history or has |
126 | | - // already yielded data. |
127 | | - if didYield || primary._hasHistory { |
| 202 | + if Task.isCancelled || error is CancellationError { |
128 | 203 | continuation.finish(throwing: error) |
129 | 204 | return |
130 | 205 | } |
131 | 206 |
|
| 207 | + // Failure to create primary session. |
132 | 208 | // Fallback to the second session if the first fails or is unavailable. |
133 | | - let stream = secondary._streamResponse( |
134 | | - to: prompt, |
135 | | - schema: schema, |
136 | | - includeSchemaInPrompt: includeSchemaInPrompt, |
137 | | - options: options |
138 | | - ) |
139 | | - |
140 | 209 | do { |
| 210 | + let secondarySession = try self.getSession(for: .secondaryModel) |
| 211 | + let stream = secondarySession._streamResponse( |
| 212 | + to: prompt, |
| 213 | + schema: schema, |
| 214 | + includeSchemaInPrompt: includeSchemaInPrompt, |
| 215 | + options: options |
| 216 | + ) |
| 217 | + |
141 | 218 | for try await snapshot in stream { |
142 | 219 | continuation.yield(snapshot) |
143 | 220 | } |
|
151 | 228 | } |
152 | 229 | } |
153 | 230 | } |
| 231 | + |
| 232 | + private func getSession(for model: SessionModel) throws -> any _ModelSession { |
| 233 | + let languageModel = (model == .primaryModel) ? primaryModel : secondaryModel |
| 234 | + |
| 235 | + // 1. Check if it exists under lock. |
| 236 | + let existing = lock.withLock { state in |
| 237 | + (model == .primaryModel) ? state.primary : state.secondary |
| 238 | + } |
| 239 | + if let existing { |
| 240 | + return existing |
| 241 | + } |
| 242 | + |
| 243 | + // 2. Create it outside the lock. |
| 244 | + let session = try languageModel._startSession(tools: tools, instructions: instructions) |
| 245 | + |
| 246 | + // 3. Try to store it under lock. |
| 247 | + return lock.withLock { state in |
| 248 | + if model == .primaryModel { |
| 249 | + if let existing = state.primary { return existing } |
| 250 | + state.primary = session |
| 251 | + } else { |
| 252 | + if let existing = state.secondary { return existing } |
| 253 | + state.secondary = session |
| 254 | + } |
| 255 | + return session |
| 256 | + } |
| 257 | + } |
154 | 258 | } |
155 | 259 | #endif // compiler(>=6.2.3) |
0 commit comments