Skip to content

Commit 011a801

Browse files
authored
Make Mqtt5 client callbacks synchronous (#330)
1 parent 403fa83 commit 011a801

2 files changed

Lines changed: 65 additions & 77 deletions

File tree

Source/AwsCommonRuntimeKit/mqtt/Mqtt5Client.swift

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
/// SPDX-License-Identifier: Apache-2.0.
33

4-
import AwsCMqtt
54
import AwsCIo
5+
import AwsCMqtt
66
import LibNative
77

88
// MARK: - Callback Data Classes
@@ -98,22 +98,25 @@ public class LifecycleDisconnectData {
9898
// MARK: - Callback typealias definitions
9999

100100
/// Defines signature of the Publish callback
101-
public typealias OnPublishReceived = (PublishReceivedData) async -> Void
101+
public typealias OnPublishReceived = @Sendable (PublishReceivedData) -> Void
102102

103103
/// Defines signature of the Lifecycle Event Stopped callback
104-
public typealias OnLifecycleEventStopped = (LifecycleStoppedData) async -> Void
104+
public typealias OnLifecycleEventStopped = @Sendable (LifecycleStoppedData) -> Void
105105

106106
/// Defines signature of the Lifecycle Event Attempting Connect callback
107-
public typealias OnLifecycleEventAttemptingConnect = (LifecycleAttemptingConnectData) async -> Void
107+
public typealias OnLifecycleEventAttemptingConnect = @Sendable (LifecycleAttemptingConnectData) ->
108+
Void
108109

109110
/// Defines signature of the Lifecycle Event Connection Success callback
110-
public typealias OnLifecycleEventConnectionSuccess = (LifecycleConnectionSuccessData) async -> Void
111+
public typealias OnLifecycleEventConnectionSuccess = @Sendable (LifecycleConnectionSuccessData) ->
112+
Void
111113

112114
/// Defines signature of the Lifecycle Event Connection Failure callback
113-
public typealias OnLifecycleEventConnectionFailure = (LifecycleConnectionFailureData) async -> Void
115+
public typealias OnLifecycleEventConnectionFailure = @Sendable (LifecycleConnectionFailureData) ->
116+
Void
114117

115118
/// Defines signature of the Lifecycle Event Disconnection callback
116-
public typealias OnLifecycleEventDisconnection = (LifecycleDisconnectData) async -> Void
119+
public typealias OnLifecycleEventDisconnection = @Sendable (LifecycleDisconnectData) -> Void
117120

118121
/// Callback for users to invoke upon completion of, presumably asynchronous, OnWebSocketHandshakeIntercept callback's initiated process.
119122
public typealias OnWebSocketHandshakeInterceptComplete = (HTTPRequestBase, Int32) -> Void
@@ -122,7 +125,7 @@ public typealias OnWebSocketHandshakeInterceptComplete = (HTTPRequestBase, Int32
122125
/// such as signing/authorization etc... Returning from this function does not continue the websocket
123126
/// handshake since some work flows may be asynchronous. To accommodate that, onComplete must be invoked upon
124127
/// completion of the signing process.
125-
public typealias OnWebSocketHandshakeIntercept = (HTTPRequest, @escaping OnWebSocketHandshakeInterceptComplete) async -> Void
128+
public typealias OnWebSocketHandshakeIntercept = @Sendable (HTTPRequest, @escaping OnWebSocketHandshakeInterceptComplete) -> Void
126129

127130
// MARK: - Mqtt5 Client
128131
public class Mqtt5Client {
@@ -236,22 +239,22 @@ public class Mqtt5ClientCore {
236239

237240
try clientOptions.validateConversionToNative()
238241

239-
self.onPublishReceivedCallback = clientOptions.onPublishReceivedFn ?? { (_) in return }
240-
self.onLifecycleEventStoppedCallback = clientOptions.onLifecycleEventStoppedFn ?? { (_) in return}
241-
self.onLifecycleEventAttemptingConnect = clientOptions.onLifecycleEventAttemptingConnectFn ?? { (_) in return}
242-
self.onLifecycleEventConnectionSuccess = clientOptions.onLifecycleEventConnectionSuccessFn ?? { (_) in return}
243-
self.onLifecycleEventConnectionFailure = clientOptions.onLifecycleEventConnectionFailureFn ?? { (_) in return}
244-
self.onLifecycleEventDisconnection = clientOptions.onLifecycleEventDisconnectionFn ?? { (_) in return}
242+
self.onPublishReceivedCallback = clientOptions.onPublishReceivedFn ?? { (_) in }
243+
self.onLifecycleEventStoppedCallback = clientOptions.onLifecycleEventStoppedFn ?? { (_) in }
244+
self.onLifecycleEventAttemptingConnect = clientOptions.onLifecycleEventAttemptingConnectFn ?? { (_) in }
245+
self.onLifecycleEventConnectionSuccess = clientOptions.onLifecycleEventConnectionSuccessFn ?? { (_) in }
246+
self.onLifecycleEventConnectionFailure = clientOptions.onLifecycleEventConnectionFailureFn ?? { (_) in }
247+
self.onLifecycleEventDisconnection = clientOptions.onLifecycleEventDisconnectionFn ?? { (_) in }
245248
self.onWebsocketInterceptor = clientOptions.onWebsocketTransform
246249

247250
guard let rawValue = (clientOptions.withCPointer(
248251
userData: Unmanaged<Mqtt5ClientCore>.passRetained(self).toOpaque()) { optionsPointer in
249-
return aws_mqtt5_client_new(allocator.rawValue, optionsPointer)
250-
}) else {
251-
// failed to create client, release the callback core
252-
Unmanaged<Mqtt5ClientCore>.passUnretained(self).release()
253-
throw CommonRunTimeError.crtError(.makeFromLastError())
254-
}
252+
return aws_mqtt5_client_new(allocator.rawValue, optionsPointer)
253+
}) else {
254+
// failed to create client, release the callback core
255+
Unmanaged<Mqtt5ClientCore>.passUnretained(self).release()
256+
throw CommonRunTimeError.crtError(.makeFromLastError())
257+
}
255258
self.rawValue = rawValue
256259
}
257260

@@ -443,9 +446,8 @@ internal func MqttClientHandleLifecycleEvent(_ lifecycleEvent: UnsafePointer<aws
443446
case AWS_MQTT5_CLET_ATTEMPTING_CONNECT:
444447

445448
let lifecycleAttemptingConnectData = LifecycleAttemptingConnectData()
446-
Task {
447-
await clientCore.onLifecycleEventAttemptingConnect(lifecycleAttemptingConnectData)
448-
}
449+
clientCore.onLifecycleEventAttemptingConnect(lifecycleAttemptingConnectData)
450+
449451
case AWS_MQTT5_CLET_CONNECTION_SUCCESS:
450452

451453
guard let connackView = lifecycleEvent.pointee.connack_data else {
@@ -460,9 +462,8 @@ internal func MqttClientHandleLifecycleEvent(_ lifecycleEvent: UnsafePointer<aws
460462
let lifecycleConnectionSuccessData = LifecycleConnectionSuccessData(
461463
connackPacket: connackPacket,
462464
negotiatedSettings: NegotiatedSettings(negotiatedSettings))
463-
Task {
464-
await clientCore.onLifecycleEventConnectionSuccess(lifecycleConnectionSuccessData)
465-
}
465+
clientCore.onLifecycleEventConnectionSuccess(lifecycleConnectionSuccessData)
466+
466467
case AWS_MQTT5_CLET_CONNECTION_FAILURE:
467468

468469
var connackPacket: ConnackPacket?
@@ -473,9 +474,8 @@ internal func MqttClientHandleLifecycleEvent(_ lifecycleEvent: UnsafePointer<aws
473474
let lifecycleConnectionFailureData = LifecycleConnectionFailureData(
474475
crtError: crtError,
475476
connackPacket: connackPacket)
476-
Task {
477-
await clientCore.onLifecycleEventConnectionFailure(lifecycleConnectionFailureData)
478-
}
477+
clientCore.onLifecycleEventConnectionFailure(lifecycleConnectionFailureData)
478+
479479
case AWS_MQTT5_CLET_DISCONNECTION:
480480

481481
var disconnectPacket: DisconnectPacket?
@@ -487,13 +487,11 @@ internal func MqttClientHandleLifecycleEvent(_ lifecycleEvent: UnsafePointer<aws
487487
let lifecycleDisconnectData = LifecycleDisconnectData(
488488
crtError: crtError,
489489
disconnectPacket: disconnectPacket)
490-
Task {
491-
await clientCore.onLifecycleEventDisconnection(lifecycleDisconnectData)
492-
}
490+
clientCore.onLifecycleEventDisconnection(lifecycleDisconnectData)
491+
493492
case AWS_MQTT5_CLET_STOPPED:
494-
Task {
495-
await clientCore.onLifecycleEventStoppedCallback(LifecycleStoppedData())
496-
}
493+
clientCore.onLifecycleEventStoppedCallback(LifecycleStoppedData())
494+
497495
default:
498496
fatalError("A lifecycle event with an invalid event type was encountered.")
499497
}
@@ -511,9 +509,7 @@ internal func MqttClientHandlePublishRecieved(
511509
if let publish {
512510
let publishPacket = PublishPacket(publish)
513511
let publishReceivedData = PublishReceivedData(publishPacket: publishPacket)
514-
Task {
515-
await clientCore.onPublishReceivedCallback(publishReceivedData)
516-
}
512+
clientCore.onPublishReceivedCallback(publishReceivedData)
517513
} else {
518514
fatalError("MqttClientHandlePublishRecieved called with null publish")
519515
}
@@ -541,9 +537,7 @@ internal func MqttClientWebsocketTransform(
541537
}
542538

543539
if clientCore.onWebsocketInterceptor != nil {
544-
Task {
545-
await clientCore.onWebsocketInterceptor!(httpRequest, signerTransform)
546-
}
540+
clientCore.onWebsocketInterceptor!(httpRequest, signerTransform)
547541
}
548542
}
549543
}

Test/AwsCommonRuntimeKitTests/mqtt/Mqtt5ClientTests.swift

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -648,17 +648,15 @@ class Mqtt5ClientTests: XCBaseTestCase {
648648

649649
// We manually setup the websocket transform to avoid recursive reference between provider and test context
650650
let onWebsocketTransform : OnWebSocketHandshakeIntercept = { httpRequest, completCallback in
651-
do
652-
{
653-
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
654-
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
655-
}
656-
catch CommonRunTimeError.crtError (let error) {
657-
completCallback(httpRequest, Int32(error.code))
658-
}
659-
catch
660-
{
661-
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
651+
Task {
652+
do {
653+
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
654+
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
655+
} catch CommonRunTimeError.crtError (let error) {
656+
completCallback(httpRequest, Int32(error.code))
657+
} catch {
658+
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
659+
}
662660
}
663661
}
664662

@@ -730,17 +728,15 @@ class Mqtt5ClientTests: XCBaseTestCase {
730728

731729
// We manually setup the websocket transform to avoid recursive reference between provider and test context
732730
let onWebsocketTransform : OnWebSocketHandshakeIntercept = { httpRequest, completCallback in
733-
do
734-
{
735-
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
736-
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
737-
}
738-
catch CommonRunTimeError.crtError (let error) {
739-
completCallback(httpRequest, Int32(error.code))
740-
}
741-
catch
742-
{
743-
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
731+
Task {
732+
do {
733+
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
734+
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
735+
} catch CommonRunTimeError.crtError (let error) {
736+
completCallback(httpRequest, Int32(error.code))
737+
} catch {
738+
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
739+
}
744740
}
745741
}
746742

@@ -850,20 +846,18 @@ class Mqtt5ClientTests: XCBaseTestCase {
850846
omitSessionToken: true)
851847

852848
let onWebsocketTransform : OnWebSocketHandshakeIntercept = { httpRequest, completCallback in
853-
do
854-
{
855-
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
856-
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
857-
print("complete signing")
858-
}
859-
catch CommonRunTimeError.crtError (let error) {
860-
completCallback(httpRequest, Int32(error.code))
861-
print("signing failed with crterror")
862-
}
863-
catch
864-
{
865-
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
866-
print("signing failed")
849+
Task {
850+
do {
851+
let returnedHttpRequest = try await Signer.signRequest(request: httpRequest, config:signingConfig)
852+
completCallback(returnedHttpRequest, AWS_OP_SUCCESS)
853+
print("complete signing")
854+
} catch CommonRunTimeError.crtError (let error) {
855+
completCallback(httpRequest, Int32(error.code))
856+
print("signing failed with crterror")
857+
} catch {
858+
completCallback(httpRequest, Int32(AWS_ERROR_UNSUPPORTED_OPERATION.rawValue))
859+
print("signing failed")
860+
}
867861
}
868862
}
869863

0 commit comments

Comments
 (0)