diff --git a/FirebaseAI/Sources/Types/Internal/Live/AsyncWebSocket.swift b/FirebaseAI/Sources/Types/Internal/Live/AsyncWebSocket.swift index 38503be2ef5..ce216828f77 100644 --- a/FirebaseAI/Sources/Types/Internal/Live/AsyncWebSocket.swift +++ b/FirebaseAI/Sources/Types/Internal/Live/AsyncWebSocket.swift @@ -29,6 +29,9 @@ final class AsyncWebSocket: Sendable { private let continuationFinished = UnfairLock(false) private let closeError: UnfairLock + private let receivingTask = UnfairLock?>(nil) + private let pingTask = UnfairLock?>(nil) + init(urlSession: URLSession = GenAIURLSession.default, urlRequest: URLRequest) { webSocketTask = urlSession.webSocketTask(with: urlRequest) (stream, continuation) = AsyncThrowingStream @@ -45,6 +48,7 @@ final class AsyncWebSocket: Sendable { webSocketTask.resume() closeError.withLock { $0 = nil } startReceiving() + startPinging() return stream } @@ -66,20 +70,44 @@ final class AsyncWebSocket: Sendable { } private func startReceiving() { - Task { - while !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil { - do { - let message = try await webSocketTask.receive() - continuation.yield(message) - } catch { - if let error = webSocketTask.error as? NSError { - close( - code: webSocketTask.closeCode, - reason: webSocketTask.closeReason, + receivingTask.withLock { [weak self] task in + task?.cancel() + task = Task { [weak self] in + while let self, + !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil { + do { + let message = try await self.webSocketTask.receive() + self.continuation.yield(message) + } catch { + self.close( + code: self.webSocketTask.closeCode, + reason: self.webSocketTask.closeReason, underlyingError: error ) - } else { - close(code: webSocketTask.closeCode, reason: webSocketTask.closeReason) + } + } + } + } + } + + private func startPinging() { + pingTask.withLock { [weak self] task in + task?.cancel() + task = Task { [weak self] in + while let self, + !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil { + try? await Task.sleep(nanoseconds: 30 * 1_000_000_000) + guard !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil + else { return } + + self.webSocketTask.sendPing { [weak self] error in + if let error { + self?.close( + code: .abnormalClosure, + reason: nil, + underlyingError: error + ) + } } } } @@ -99,6 +127,8 @@ final class AsyncWebSocket: Sendable { } webSocketTask.cancel(with: code, reason: reason) + receivingTask.value()?.cancel() + pingTask.value()?.cancel() continuationFinished.withLock { isFinished in guard !isFinished else { return } diff --git a/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift b/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift index 3f4f0ef287e..5a595b3052e 100644 --- a/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift +++ b/FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift @@ -20,6 +20,7 @@ import os.log // https://forums.swift.org/t/why-does-sending-a-sendable-value-risk-causing-data-races/73074 @preconcurrency import FirebaseAppCheckInterop @preconcurrency import FirebaseAuthInterop +private import FirebaseCoreInternal /// Facilitates communication with the backend for a ``LiveSession``. /// @@ -31,9 +32,16 @@ import os.log /// session is being reloaded. @available(watchOS, unavailable) actor LiveSessionService { - let responses: AsyncThrowingStream - private let responseContinuation: AsyncThrowingStream - .Continuation + private typealias StreamState = ( + responses: AsyncThrowingStream, + continuation: AsyncThrowingStream.Continuation, + isFinished: Bool + ) + private let streamState: UnfairLock + + nonisolated var responses: AsyncThrowingStream { + streamState.value().responses + } // to ensure messages are sent in order, since swift actors are reentrant private var messageQueue: AsyncStream @@ -71,7 +79,13 @@ actor LiveSessionService { toolConfig: ToolConfig?, systemInstruction: ModelContent?, requestOptions: RequestOptions) { - (responses, responseContinuation) = AsyncThrowingStream.makeStream() + let (responses, responseContinuation) = AsyncThrowingStream + .makeStream() + streamState = UnfairLock(( + responses: responses, + continuation: responseContinuation, + isFinished: false + )) (messageQueue, messageQueueContinuation) = AsyncStream.makeStream() self.modelResourceName = modelResourceName self.generationConfig = generationConfig @@ -92,7 +106,10 @@ actor LiveSessionService { // we only finish the streams when the actor deinits; while the actor is still in scope, the // user could continue using the streams via resumeSession (even after calling close) messageQueueContinuation.finish() - responseContinuation.finish() + streamState.withLock { state in + state.continuation.finish() + state.isFinished = true + } webSocket = nil responsesTask = nil @@ -115,7 +132,15 @@ actor LiveSessionService { /// /// This function will yield until the websocket is ready to communicate with the client. func connect(sessionResumption: SessionResumptionConfig? = nil) async throws { - close() + close(finishingStream: false) + + streamState.withLock { state in + if state.isFinished { + let (responses, responseContinuation) = AsyncThrowingStream + .makeStream() + state = (responses: responses, continuation: responseContinuation, isFinished: false) + } + } let stream = try await setupWebsocket() try await waitForSetupComplete(stream: stream, sessionResumption: sessionResumption) @@ -125,7 +150,10 @@ actor LiveSessionService { /// Cancel any running tasks and close the websocket. /// /// This method is idempotent; if it's already ran once, it will effectively be a no-op. - func close() { + /// + /// - Parameters: + /// - finishingStream: Whether to also finish the public ``responses`` stream. + func close(finishingStream: Bool = false) { responsesTask?.cancel() messageQueueTask?.cancel() webSocket?.disconnect() @@ -133,6 +161,17 @@ actor LiveSessionService { webSocket = nil responsesTask = nil messageQueueTask = nil + + // Finish the message queue so the messageQueueTask's `for await` loop exits immediately, + // rather than waiting for task cancellation to take effect. + messageQueueContinuation.finish() + + if finishingStream { + streamState.withLock { state in + state.continuation.finish() + state.isFinished = true + } + } } /// Performs the initial setup procedure for the model. @@ -163,7 +202,7 @@ actor LiveSessionService { try await webSocket.send(.data(data)) } catch { let error = LiveSessionSetupError(underlyingError: error) - close() + close(finishingStream: true) throw error } @@ -181,7 +220,7 @@ actor LiveSessionService { } } catch { if let error = mapWebsocketError(error) { - close() + close(finishingStream: true) throw error } // the user called close while setup was running @@ -221,7 +260,7 @@ actor LiveSessionService { } } catch { let error = LiveSessionSetupError(underlyingError: error) - close() + close(finishingStream: true) throw error } } @@ -233,20 +272,21 @@ actor LiveSessionService { /// - `responsesTask`: Listen to messages from the server and yield them through `responses`. /// - `messageQueueTask`: Listen to messages from the client and send them through the websocket. private func spawnMessageTasks(stream: MappedStream) { - guard let webSocket else { return } - // we create a new messageQueue since the iterator below will cancel the old one when the - // task is cancelled. this will cause issues when trying to restart a session via resumeSession + guard webSocket != nil else { return } + // Create a fresh messageQueue so the new task iterates its own stream. The old stream was + // finished in close(), and the old task was cancelled, so they are both done at this point. (messageQueue, messageQueueContinuation) = AsyncStream.makeStream() - responsesTask = Task { + responsesTask = Task { [weak self] in do { for try await message in stream { + guard let self else { return } #if DEBUG if #available(macOS 11.0, *) { - logServerMessage(message) + self.logServerMessage(message) } #endif - let response = try decodeServerMessage(message) + let response = try self.decodeServerMessage(message) if case .setupComplete = response.messageType { AILog.debug( @@ -262,25 +302,51 @@ actor LiveSessionService { ) } - responseContinuation.yield(liveMessage) + self.streamState.withLock { state in + state.continuation.yield(liveMessage) + } + } + } + // loop finished normally (websocket closed normally) + guard let self = self else { return } + if !Task.isCancelled { + self.streamState.withLock { state in + state.continuation.finish() + state.isFinished = true } } } catch { - if let error = mapWebsocketError(error) { - close() - responseContinuation.finish(throwing: error) + guard let self = self else { return } + if Task.isCancelled { return } + + if let error = self.mapWebsocketError(error) { + await self.close(finishingStream: false) + self.streamState.withLock { state in + state.continuation.finish(throwing: error) + state.isFinished = true + } + } else { + // normal closure (mapped to nil) + await self.close(finishingStream: true) } } } - messageQueueTask = Task { + let messageQueue = self.messageQueue + messageQueueTask = Task { [weak self] in for await message in messageQueue { - guard let data = encodeClientMessage(message) else { continue } + guard let self = self else { return } + guard let data = self.encodeClientMessage(message) else { continue } do { - try await webSocket.send(.data(data)) + try await self.webSocket?.send(.data(data)) } catch { AILog.error(code: .liveSessionFailedToSendClientMessage, error.localizedDescription) + await self.close(finishingStream: false) + self.streamState.withLock { state in + state.continuation.finish(throwing: error) + state.isFinished = true + } } } } @@ -288,7 +354,7 @@ actor LiveSessionService { #if DEBUG @available(macOS 11.0, *) - private func logServerMessage(_ message: Data) { + private nonisolated func logServerMessage(_ message: Data) { guard AILog.additionalLoggingEnabled() else { return } guard let message = JSONSerialization.prettyString(with: message) else { return } @@ -306,7 +372,7 @@ actor LiveSessionService { /// /// Some errors have public api alternatives. This function will ensure they're mapped /// accordingly. - private func mapWebsocketError(_ error: Error) -> Error? { + private nonisolated func mapWebsocketError(_ error: Error) -> Error? { if let error = error as? WebSocketClosedError { // only raise an error if the session didn't close normally (ie; the user calling close) if error.closeCode == .goingAway { @@ -331,7 +397,8 @@ actor LiveSessionService { /// Decodes a message from the server's websocket into a valid `BidiGenerateContentServerMessage`. /// /// Will throw an error if decoding fails. - private func decodeServerMessage(_ message: Data) throws -> BidiGenerateContentServerMessage { + private nonisolated func decodeServerMessage(_ message: Data) throws + -> BidiGenerateContentServerMessage { do { return try jsonDecoder.decode( BidiGenerateContentServerMessage.self, @@ -357,7 +424,8 @@ actor LiveSessionService { /// Encodes a message from the client into `Data` that can be sent through a websocket data frame. /// /// Will return `nil` if decoding fails, and log an error describing why. - private func encodeClientMessage(_ message: BidiGenerateContentClientMessage) -> Data? { + private nonisolated func encodeClientMessage(_ message: BidiGenerateContentClientMessage) + -> Data? { do { return try jsonEncoder.encode(message) } catch { diff --git a/FirebaseAI/Sources/Types/Public/Live/LiveSession.swift b/FirebaseAI/Sources/Types/Public/Live/LiveSession.swift index 8dd3fa813e9..8ffdc2806ff 100644 --- a/FirebaseAI/Sources/Types/Public/Live/LiveSession.swift +++ b/FirebaseAI/Sources/Types/Public/Live/LiveSession.swift @@ -35,6 +35,13 @@ public final class LiveSession: Sendable { self.service = service } + deinit { + let service = self.service + Task { + await service.close(finishingStream: true) + } + } + /// Response to a ``LiveServerToolCall`` received from the server. /// /// This method is used both for the realtime API and the incremental API. @@ -142,7 +149,7 @@ public final class LiveSession: Sendable { /// Attempting to receive content from a closed session will cause a /// ``LiveSessionUnexpectedClosureError`` error to be thrown. public func close() async { - await service.close() + await service.close(finishingStream: true) } /// Resumes an existing live session with the server. @@ -153,7 +160,7 @@ public final class LiveSession: Sendable { /// /// To optain a valid handle, ensure you pass an instance of /// ``SessionResumptionConfig`` to ``LiveGenerativeModel/connect(sessionResumption:)``, - /// and then listen for the hande provided from a ``LiveSessionResumptionUpdate`` + /// and then listen for the handle provided from a ``LiveSessionResumptionUpdate`` /// server message. /// /// - Parameters: diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift index d213fa51cb0..6f0396e072e 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift @@ -128,7 +128,13 @@ struct LiveSessionTests { // The model can't infer that we're done speaking until we send null bytes await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - let text = try await session.collectNextAudioOutputTranscript() + let text: String + do { + text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error + } await session.close() let modelResponse = text @@ -151,6 +157,7 @@ struct LiveSessionTests { let session = try await model.connect() guard let videoFile = NSDataAsset(name: "cat") else { Issue.record("Missing video file 'cat' in Assets") + await session.close() return } @@ -164,12 +171,19 @@ struct LiveSessionTests { // (they both respond with audio though) guard let audioFile = NSDataAsset(name: "hello") else { Issue.record("Missing audio file 'hello.wav' in Assets") + await session.close() return } await session.sendAudioRealtime(audioFile.data) await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - let text = try await session.collectNextAudioOutputTranscript() + let text: String + do { + text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error + } await session.close() let modelResponse = text @@ -193,7 +207,15 @@ struct LiveSessionTests { let session = try await model.connect() await session.sendTextRealtime("Alex") - guard let toolCall = try await session.collectNextToolCall() else { + let toolCall: LiveServerToolCall? + do { + toolCall = try await session.collectNextToolCall() + } catch { + await session.close() + throw error + } + guard let toolCall else { + await session.close() return } @@ -204,6 +226,7 @@ struct LiveSessionTests { #expect(functionCall.name == "getLastName") guard let response = getLastName(args: functionCall.args) else { + await session.close() return } await session.sendFunctionResponses([ @@ -213,9 +236,12 @@ struct LiveSessionTests { functionId: functionCall.functionId ), ]) - var text = try await session.collectNextAudioOutputTranscript() - if text.isEmpty { + let text: String + do { text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error } await session.close() @@ -241,7 +267,15 @@ struct LiveSessionTests { await session.sendTextRealtime("My name is Alex.") - guard let newHandle = try await session.collectNextSessionHandle() else { + let newHandle: String? + do { + newHandle = try await session.collectNextSessionHandle() + } catch { + await session.close() + throw error + } + await session.close() + guard let newHandle else { return } @@ -249,11 +283,14 @@ struct LiveSessionTests { sessionResumption: SessionResumptionConfig(handle: newHandle) ) - await session.sendTextRealtime("What is my name?") + let text: String + do { + await session.sendTextRealtime("What is my name?") - var text = try await session.collectNextAudioOutputTranscript() - if text.isEmpty { text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error } await session.close() @@ -283,15 +320,19 @@ struct LiveSessionTests { await session.sendTextRealtime("My name is Alex.") // re-connect without specifying the new handle (so it should be a new session) + await session.close() session = try await model.connect( sessionResumption: SessionResumptionConfig() ) - await session.sendTextRealtime("What is my name?") + let text: String + do { + await session.sendTextRealtime("What is my name?") - var text = try await session.collectNextAudioOutputTranscript() - if text.isEmpty { text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error } await session.close() @@ -318,18 +359,29 @@ struct LiveSessionTests { await session.sendTextRealtime("My name is Alex.") - guard let newHandle = try await session.collectNextSessionHandle() else { - return + let newHandle: String? + do { + newHandle = try await session.collectNextSessionHandle() + } catch { + await session.close() + throw error } await session.close() + guard let newHandle else { + return + } + try await session.resumeSession(sessionResumption: SessionResumptionConfig(handle: newHandle)) - await session.sendTextRealtime("What is my name?") + let text: String + do { + await session.sendTextRealtime("What is my name?") - var text = try await session.collectNextAudioOutputTranscript() - if text.isEmpty { text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error } await session.close() @@ -365,6 +417,7 @@ struct LiveSessionTests { await session.sendTextRealtime("Alex") guard let toolCall = try await session.collectNextToolCall() else { + await session.close() return } @@ -374,11 +427,16 @@ struct LiveSessionTests { let functionCall = try #require(functionCalls.first) let id = try #require(functionCall.functionId) - await session.sendTextRealtime("Actually, I don't care about the last name of Alex anymore.") + do { + await session.sendTextRealtime("Actually, I don't care about the last name of Alex anymore.") - for try await cancellation in session.responsesOf(LiveServerToolCallCancellation.self) { - #expect(cancellation.ids == [id]) - break + for try await cancellation in session.responsesOf(LiveServerToolCallCancellation.self) { + #expect(cancellation.ids == [id]) + break + } + } catch { + await session.close() + throw error } await session.close() @@ -398,23 +456,29 @@ struct LiveSessionTests { try await retry(times: 3, delayInSeconds: 2.0) { let session = try await model.connect() - await session.sendAudioRealtime(audioFile.data) - await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - - // Wait a second to allow the model to start generating (and cause a proper interruption) - try await Task.sleep(nanoseconds: oneSecondInNanoseconds) - await session.sendAudioRealtime(audioFile.data) - await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) - - for try await content in session.responsesOf(LiveServerContent.self) { - if content.wasInterrupted { - break - } - - if content.isTurnComplete { - throw NoInterruptionError() + do { + await session.sendAudioRealtime(audioFile.data) + await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) + + // Wait a second to allow the model to start generating (and cause a proper interruption) + try await Task.sleep(nanoseconds: oneSecondInNanoseconds) + await session.sendAudioRealtime(audioFile.data) + await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count)) + + for try await content in session.responsesOf(LiveServerContent.self) { + if content.wasInterrupted { + break + } + + if content.isTurnComplete { + throw NoInterruptionError() + } } + } catch { + await session.close() + throw error } + await session.close() } } @@ -427,13 +491,15 @@ struct LiveSessionTests { ) let session = try await model.connect() - await session.sendContent("Does five plus") - await session.sendContent(" five equal ten?", turnComplete: true) + let text: String + do { + await session.sendContent("Does five plus") + await session.sendContent(" five equal ten?", turnComplete: true) - var text = try await session.collectNextAudioOutputTranscript() - if text.isEmpty { - // The model sometimes sends an empty text response first text = try await session.collectNextAudioOutputTranscript() + } catch { + await session.close() + throw error } await session.close() @@ -469,6 +535,12 @@ private extension LiveSession { /// Once the model signals that its turn is complete, the function will return /// a string concatenated of all the `LiveAudioTranscription`s. func collectNextAudioOutputTranscript() async throws -> String { + // The model sometimes sends an empty text response first + let text = try await collectNextTurn() + return text.isEmpty ? try await collectNextTurn() : text + } + + private func collectNextTurn() async throws -> String { var text = "" for try await content in responsesOf(LiveServerContent.self) {