Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions FirebaseAI/Sources/Types/Internal/Live/AsyncWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ final class AsyncWebSocket: Sendable {
private let continuationFinished = UnfairLock<Bool>(false)
private let closeError: UnfairLock<WebSocketClosedError?>

private let receivingTask = UnfairLock<Task<Void, Never>?>(nil)
private let pingTask = UnfairLock<Task<Void, Never>?>(nil)

init(urlSession: URLSession = GenAIURLSession.default, urlRequest: URLRequest) {
webSocketTask = urlSession.webSocketTask(with: urlRequest)
(stream, continuation) = AsyncThrowingStream<URLSessionWebSocketTask.Message, Error>
Expand All @@ -45,6 +48,7 @@ final class AsyncWebSocket: Sendable {
webSocketTask.resume()
closeError.withLock { $0 = nil }
startReceiving()
startPinging()
return stream
}

Expand All @@ -66,20 +70,44 @@ final class AsyncWebSocket: Sendable {
}

private func startReceiving() {
Task {
while !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil {
do {
let message = try await webSocketTask.receive()
continuation.yield(message)
} catch {
if let error = webSocketTask.error as? NSError {
close(
code: webSocketTask.closeCode,
reason: webSocketTask.closeReason,
receivingTask.withLock { [weak self] task in
task?.cancel()
task = Task { [weak self] in
while let self,
!Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil {
do {
let message = try await self.webSocketTask.receive()
self.continuation.yield(message)
} catch {
self.close(
code: self.webSocketTask.closeCode,
reason: self.webSocketTask.closeReason,
underlyingError: error
)
} else {
close(code: webSocketTask.closeCode, reason: webSocketTask.closeReason)
}
}
}
}
}

private func startPinging() {
pingTask.withLock { [weak self] task in
task?.cancel()
task = Task { [weak self] in
while let self,
!Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil {
try? await Task.sleep(nanoseconds: 30 * 1_000_000_000)
Comment thread
andrewheard marked this conversation as resolved.
guard !Task.isCancelled && self.webSocketTask.isOpen && self.closeError.value() == nil
else { return }

self.webSocketTask.sendPing { [weak self] error in
if let error {
self?.close(
code: .abnormalClosure,
reason: nil,
underlyingError: error
)
}
}
}
}
Expand All @@ -99,6 +127,8 @@ final class AsyncWebSocket: Sendable {
}

webSocketTask.cancel(with: code, reason: reason)
receivingTask.value()?.cancel()
pingTask.value()?.cancel()

continuationFinished.withLock { isFinished in
guard !isFinished else { return }
Expand Down
104 changes: 83 additions & 21 deletions FirebaseAI/Sources/Types/Internal/Live/LiveSessionService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import os.log
// https://forums.swift.org/t/why-does-sending-a-sendable-value-risk-causing-data-races/73074
@preconcurrency import FirebaseAppCheckInterop
@preconcurrency import FirebaseAuthInterop
private import FirebaseCoreInternal

/// Facilitates communication with the backend for a ``LiveSession``.
///
Expand All @@ -31,9 +32,16 @@ import os.log
/// session is being reloaded.
@available(watchOS, unavailable)
actor LiveSessionService {
let responses: AsyncThrowingStream<LiveServerMessage, Error>
private let responseContinuation: AsyncThrowingStream<LiveServerMessage, Error>
.Continuation
private typealias StreamState = (
responses: AsyncThrowingStream<LiveServerMessage, Error>,
continuation: AsyncThrowingStream<LiveServerMessage, Error>.Continuation,
isFinished: Bool
)
private let streamState: UnfairLock<StreamState>

nonisolated var responses: AsyncThrowingStream<LiveServerMessage, Error> {
streamState.value().responses
}

// to ensure messages are sent in order, since swift actors are reentrant
private var messageQueue: AsyncStream<BidiGenerateContentClientMessage>
Expand Down Expand Up @@ -71,7 +79,13 @@ actor LiveSessionService {
toolConfig: ToolConfig?,
systemInstruction: ModelContent?,
requestOptions: RequestOptions) {
(responses, responseContinuation) = AsyncThrowingStream.makeStream()
let (responses, responseContinuation) = AsyncThrowingStream<LiveServerMessage, Error>
.makeStream()
streamState = UnfairLock((
responses: responses,
continuation: responseContinuation,
isFinished: false
))
(messageQueue, messageQueueContinuation) = AsyncStream.makeStream()
self.modelResourceName = modelResourceName
self.generationConfig = generationConfig
Expand All @@ -92,7 +106,10 @@ actor LiveSessionService {
// we only finish the streams when the actor deinits; while the actor is still in scope, the
// user could continue using the streams via resumeSession (even after calling close)
messageQueueContinuation.finish()
responseContinuation.finish()
streamState.withLock { state in
state.continuation.finish()
state.isFinished = true
}

webSocket = nil
responsesTask = nil
Expand All @@ -115,7 +132,15 @@ actor LiveSessionService {
///
/// This function will yield until the websocket is ready to communicate with the client.
func connect(sessionResumption: SessionResumptionConfig? = nil) async throws {
close()
close(finishingStream: false)

streamState.withLock { state in
if state.isFinished {
let (responses, responseContinuation) = AsyncThrowingStream<LiveServerMessage, Error>
.makeStream()
state = (responses: responses, continuation: responseContinuation, isFinished: false)
}
}

let stream = try await setupWebsocket()
try await waitForSetupComplete(stream: stream, sessionResumption: sessionResumption)
Expand All @@ -125,14 +150,24 @@ actor LiveSessionService {
/// Cancel any running tasks and close the websocket.
///
/// This method is idempotent; if it's already ran once, it will effectively be a no-op.
func close() {
///
/// - Parameters:
/// - finishingStream: Whether to also finish the public ``responses`` stream.
func close(finishingStream: Bool = false) {
responsesTask?.cancel()
messageQueueTask?.cancel()
webSocket?.disconnect()

webSocket = nil
responsesTask = nil
messageQueueTask = nil

if finishingStream {
streamState.withLock { state in
state.continuation.finish()
state.isFinished = true
}
}
}

/// Performs the initial setup procedure for the model.
Expand Down Expand Up @@ -163,7 +198,7 @@ actor LiveSessionService {
try await webSocket.send(.data(data))
} catch {
let error = LiveSessionSetupError(underlyingError: error)
close()
close(finishingStream: true)
throw error
}

Expand All @@ -181,7 +216,7 @@ actor LiveSessionService {
}
} catch {
if let error = mapWebsocketError(error) {
close()
close(finishingStream: true)
throw error
}
// the user called close while setup was running
Expand Down Expand Up @@ -221,7 +256,7 @@ actor LiveSessionService {
}
} catch {
let error = LiveSessionSetupError(underlyingError: error)
close()
close(finishingStream: true)
throw error
}
}
Expand All @@ -233,20 +268,21 @@ actor LiveSessionService {
/// - `responsesTask`: Listen to messages from the server and yield them through `responses`.
/// - `messageQueueTask`: Listen to messages from the client and send them through the websocket.
private func spawnMessageTasks(stream: MappedStream<URLSessionWebSocketTask.Message, Data>) {
guard let webSocket else { return }
guard webSocket != nil else { return }
// we create a new messageQueue since the iterator below will cancel the old one when the
// task is cancelled. this will cause issues when trying to restart a session via resumeSession
(messageQueue, messageQueueContinuation) = AsyncStream.makeStream()

responsesTask = Task {
responsesTask = Task { [weak self] in
do {
for try await message in stream {
guard let self else { return }
#if DEBUG
if #available(macOS 11.0, *) {
logServerMessage(message)
await self.logServerMessage(message)
}
#endif
let response = try decodeServerMessage(message)
let response = try await self.decodeServerMessage(message)
Comment thread
andrewheard marked this conversation as resolved.
Outdated

if case .setupComplete = response.messageType {
AILog.debug(
Expand All @@ -262,25 +298,51 @@ actor LiveSessionService {
)
}

responseContinuation.yield(liveMessage)
self.streamState.withLock { state in
state.continuation.yield(liveMessage)
}
}
}
// loop finished normally (websocket closed normally)
guard let self = self else { return }
if !Task.isCancelled {
self.streamState.withLock { state in
state.continuation.finish()
state.isFinished = true
}
}
} catch {
if let error = mapWebsocketError(error) {
close()
responseContinuation.finish(throwing: error)
guard let self = self else { return }
if Task.isCancelled { return }

if let error = await self.mapWebsocketError(error) {
await self.close(finishingStream: true)
self.streamState.withLock { state in
state.continuation.finish(throwing: error)
state.isFinished = true
}
Comment thread
andrewheard marked this conversation as resolved.
Outdated
} else {
// normal closure (mapped to nil)
await self.close(finishingStream: true)
}
}
}

messageQueueTask = Task {
let messageQueue = self.messageQueue
messageQueueTask = Task { [weak self] in
for await message in messageQueue {
Comment thread
andrewheard marked this conversation as resolved.
guard let data = encodeClientMessage(message) else { continue }
guard let self = self else { return }
guard let data = await self.encodeClientMessage(message) else { continue }

do {
try await webSocket.send(.data(data))
try await self.webSocket?.send(.data(data))
} catch {
AILog.error(code: .liveSessionFailedToSendClientMessage, error.localizedDescription)
await self.close(finishingStream: true)
self.streamState.withLock { state in
state.continuation.finish(throwing: error)
state.isFinished = true
}
Comment thread
andrewheard marked this conversation as resolved.
Outdated
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion FirebaseAI/Sources/Types/Public/Live/LiveSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public final class LiveSession: Sendable {
self.service = service
}

deinit {
let service = self.service
Task {
await service.close()
}
}

/// Response to a ``LiveServerToolCall`` received from the server.
///
/// This method is used both for the realtime API and the incremental API.
Expand Down Expand Up @@ -142,7 +149,7 @@ public final class LiveSession: Sendable {
/// Attempting to receive content from a closed session will cause a
/// ``LiveSessionUnexpectedClosureError`` error to be thrown.
public func close() async {
await service.close()
await service.close(finishingStream: true)
}

/// Resumes an existing live session with the server.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 15.0;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MACOSX_DEPLOYMENT_TARGET = 12.0;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.google.firebase.FirebaseAITestAppTests;
Expand All @@ -459,7 +459,7 @@
CURRENT_PROJECT_VERSION = 1;
DEAD_CODE_STRIPPING = YES;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 15.0;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MACOSX_DEPLOYMENT_TARGET = 12.0;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = com.google.firebase.FirebaseAITestAppTests;
Expand Down
Loading
Loading