Skip to content

Commit

Permalink
fix: mqtt connection bugs (#55)
Browse files Browse the repository at this point in the history
* make closures public

* fix connect

* bug fix

* adjust callback code

* clean up pointers

* fix c pointer issue

* fix buld

* fix param order

* fix connection callback handlers

* fix initializer

* lint

* changes

* linter
  • Loading branch information
kneekey23 authored Nov 2, 2021
1 parent 5f8a6b9 commit faff91d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 61 deletions.
2 changes: 1 addition & 1 deletion Source/AwsCommonRuntimeKit/Utilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func allocatePointer<T>(_ capacity: Int = 1) -> UnsafeMutablePointer<T> {
func toPointerArray<T, P: PointerConformance>(_ array: [T]) -> P {
let pointers = UnsafeMutablePointer<T>.allocate(capacity: array.count)

for index in 0...array.count {
for index in 0..<array.count {
pointers.advanced(by: index).initialize(to: array[index])
}
return P(OpaquePointer(pointers))
Expand Down
96 changes: 36 additions & 60 deletions Source/AwsCommonRuntimeKit/mqtt/MqttConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ public typealias OnWebSocketHandshakeInterceptComplete = (HttpRequest, CRTError)

//swiftlint:disable cyclomatic_complexity file_length type_body_length opening_brace
public class MqttConnection {
var onConnectionInterrupted: OnConnectionInterrupted = {(connectionPtr, errorCode) in }
var onConnectionResumed: OnConnectionResumed = {(connectionPtr, returnCode, retain) in }
var onDisconnect: OnDisconnect = {(connectionPtr) in }
var onConnectionComplete: OnConnectionComplete = {(connectionPtr, errorCode, returnCode, retain) in}
var onWebSocketHandshakeIntercept: OnWebSocketHandshakeIntercept?
var onWebSocketHandshakeInterceptComplete: OnWebSocketHandshakeInterceptComplete?
public var onConnectionInterrupted: OnConnectionInterrupted = {(connectionPtr, errorCode) in }
public var onConnectionResumed: OnConnectionResumed = {(connectionPtr, returnCode, retain) in }
public var onDisconnect: OnDisconnect = {(connectionPtr) in }
public var onConnectionComplete: OnConnectionComplete = {(connectionPtr, errorCode, returnCode, retain) in}
public var onWebSocketHandshakeIntercept: OnWebSocketHandshakeIntercept?
public var onWebSocketHandshakeInterceptComplete: OnWebSocketHandshakeInterceptComplete?

private var allocator: Allocator
private var clientPointer: UnsafeMutablePointer<aws_mqtt_client>
Expand All @@ -37,6 +37,8 @@ public class MqttConnection {
let tlsContext: TlsContext?
var proxyOptions: HttpProxyOptions?
var pubCallbackData: UnsafeMutablePointer<PubCallbackData>?
/// This pointer has to live for the duration of all the callbacks that could be called so that is why we store it on the connection itself.
var nativePointer: UnsafeMutableRawPointer?

init(clientPointer: UnsafeMutablePointer<aws_mqtt_client>,
host: String,
Expand All @@ -58,35 +60,31 @@ public class MqttConnection {
}

private func setUpCallbackData() {

self.nativePointer = fromPointer(ptr: self)
aws_mqtt_client_connection_set_connection_interruption_handlers(rawValue, { (_, errorCode, userData) in
guard let userData = userData else {
return
}

let pointer = userData.assumingMemoryBound(to: MqttConnection.self)

defer { pointer.deinitializeAndDeallocate()}

let error = AWSError(errorCode: errorCode)

pointer.pointee.onConnectionInterrupted(pointer.pointee.rawValue,
CRTError.crtError(error))

}, rawValue, { (_, connectReturnCode, sessionPresent, userData) in
}, nativePointer, { (_, connectReturnCode, sessionPresent, userData) in
guard let userData = userData else {
return
}

let pointer = userData.assumingMemoryBound(to: MqttConnection.self)

defer { pointer.deinitializeAndDeallocate()}

pointer.pointee.onConnectionResumed(pointer.pointee.rawValue,
MqttReturnCode(rawValue: connectReturnCode),
sessionPresent)

}, rawValue)
}, nativePointer)
}

/// Sets the will message to send with the CONNECT packet.
Expand All @@ -97,14 +95,13 @@ public class MqttConnection {
/// - payload: The payload of the will message to send over as `Data`
/// - Returns: A `Bool` if will was set successfully
public func setWill(topic: String, qos: MqttQos, retain: Bool, payload: Data) -> Bool {
let topicPointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: topic.awsByteCursor)
let payloadPointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: payload.awsByteCursor)

var topicByteCursor = topic.awsByteCursor
var payloadByteCursor = payload.awsByteCursor
return aws_mqtt_client_connection_set_will(rawValue,
topicPointer,
&topicByteCursor,
qos.rawValue,
retain,
payloadPointer) == AWS_OP_SUCCESS
&payloadByteCursor) == AWS_OP_SUCCESS
}

/// Sets the username and/or password to send with the CONNECT packet.
Expand All @@ -113,11 +110,11 @@ public class MqttConnection {
/// - password: Password to authenticate with as `String`
/// - Returns: A `Bool` if login was set successfully.
public func setLogin(username: String, password: String) -> Bool {
let usernamePointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: username.awsByteCursor)
let passwordPointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: password.awsByteCursor)
var usernameByteCursor = username.awsByteCursor
var passwordByteCursor = password.awsByteCursor
return aws_mqtt_client_connection_set_login(rawValue,
usernamePointer,
passwordPointer) == AWS_OP_SUCCESS
&usernameByteCursor,
&passwordByteCursor) == AWS_OP_SUCCESS
}

/// Opens the actual connection defined this class. Once the connection is opened, `OnConnectionComplete`
Expand Down Expand Up @@ -152,7 +149,7 @@ public class MqttConnection {
mqttOptions.keep_alive_time_secs = UInt16(keepAliveTime)
mqttOptions.ping_timeout_ms = UInt32(requestTimeoutMs)
mqttOptions.clean_session = cleanSession
mqttOptions.user_data = fromPointer(ptr: self)
mqttOptions.user_data = nativePointer

mqttOptions.on_connection_complete = { (connectionPtr,
errorCode,
Expand All @@ -178,7 +175,7 @@ public class MqttConnection {
if useWebSockets {
if onWebSocketHandshakeIntercept != nil {

if aws_mqtt_client_connection_use_websockets(rawValue,
aws_mqtt_client_connection_use_websockets(rawValue,
{ (httpRequest, userData, completeFn, completeUserData) in
guard let userData = userData,
let httpRequest = httpRequest else {
Expand All @@ -195,13 +192,9 @@ public class MqttConnection {
//can unwrap here with ! because we know its not nil at this point
ptr.pointee.onWebSocketHandshakeIntercept!(HttpRequest(message: httpRequest),
onInterceptComplete)
}, rawValue, nil, nil) == AWS_OP_SUCCESS {
return false
}
}, rawValue, nil, nil)
} else {
if aws_mqtt_client_connection_use_websockets(rawValue, nil, nil, nil, nil) == AWS_OP_SUCCESS {
return false
}
aws_mqtt_client_connection_use_websockets(rawValue, nil, nil, nil, nil)
}

if let proxyOptions = proxyOptions {
Expand All @@ -221,7 +214,7 @@ public class MqttConnection {
pOptions.host = proxyOptions.hostName.awsByteCursor
pOptions.port = proxyOptions.port

if aws_mqtt_client_connection_set_http_proxy_options(rawValue, &pOptions) == AWS_OP_SUCCESS {
if aws_mqtt_client_connection_set_http_proxy_options(rawValue, &pOptions) != AWS_OP_SUCCESS {
return false
}
}
Expand Down Expand Up @@ -290,10 +283,9 @@ public class MqttConnection {
let pubCallbackPtr: UnsafeMutablePointer<PubCallbackData> = fromPointer(ptr: pubCallbackData)
let subAckCallbackData = SubAckCallbackData(onSubAck: onSubAck, connection: self, topic: nil)
let subAckCallbackPtr: UnsafeMutablePointer<SubAckCallbackData> = fromPointer(ptr: subAckCallbackData)
let topicPtr: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: topicFilter.awsByteCursor)

var topicByteCursor = topicFilter.awsByteCursor
let packetId = aws_mqtt_client_connection_subscribe(rawValue,
topicPtr,
&topicByteCursor,
qos.rawValue,
{ (_, topicPtr, payload, _, _, _, userData) in
guard let userData = userData,
Expand All @@ -302,20 +294,14 @@ public class MqttConnection {
return
}
let ptr = userData.assumingMemoryBound(to: PubCallbackData.self)
defer {
ptr.deinitializeAndDeallocate()
topicPtr?.deallocate()
}

ptr.pointee.onPublishReceived(ptr.pointee.mqttConnection, topic, payload.pointee.toData())
}, pubCallbackPtr, nil, { (_, packetId, topicPtr, qos, errorCode, userData) in
guard let userData = userData, let topic = topicPtr?.pointee.toString() else {
return
}
let ptr = userData.assumingMemoryBound(to: SubAckCallbackData.self)
defer {
ptr.deinitializeAndDeallocate()
topicPtr?.deallocate()
}

let error = AWSError(errorCode: errorCode)
ptr.pointee.onSubAck(ptr.pointee.connection,
Int16(packetId),
Expand Down Expand Up @@ -355,14 +341,10 @@ public class MqttConnection {
return
}
let ptr = userData.assumingMemoryBound(to: MultiSubAckCallbackData.self)
defer {ptr.deinitializeAndDeallocate()}
var topics = [String]()
for index in 0...topicPointers.pointee.current_size {
let pointer = topicPointers.pointee.data.advanced(by: index)
let swiftString = pointer.assumingMemoryBound(to: String.self)
defer {
pointer.deallocate()
}
topics.append(swiftString.pointee)
}
let error = AWSError(errorCode: errorCode)
Expand All @@ -383,17 +365,15 @@ public class MqttConnection {
public func unsubscribe(topicFilter: String, onComplete: @escaping OnOperationComplete) -> UInt16 {
let opCallbackData = OpCompleteCallbackData(connection: self, onOperationComplete: onComplete)
let opCallbackPtr: UnsafeMutablePointer<OpCompleteCallbackData> = fromPointer(ptr: opCallbackData)
let topicPtr: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: topicFilter.awsByteCursor)
var topicByteCursor = topicFilter.awsByteCursor
let packetId = aws_mqtt_client_connection_unsubscribe(rawValue,
topicPtr,
&topicByteCursor,
{ (_, packetId, errorCode, userData) in
guard let userData = userData else {
return
}
let ptr = userData.assumingMemoryBound(to: OpCompleteCallbackData.self)
defer {
ptr.deinitializeAndDeallocate()
}

let error = AWSError(errorCode: errorCode)
ptr.pointee.onOperationComplete(ptr.pointee.connection,
Int16(packetId),
Expand All @@ -420,29 +400,25 @@ public class MqttConnection {
connection: self,
onOperationComplete: onComplete)
let opCallbackPtr: UnsafeMutablePointer<OpCompleteCallbackData> = fromPointer(ptr: opCallbackData)
let payloadPointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: payload.awsByteCursor)
let topicPointer: UnsafeMutablePointer<aws_byte_cursor> = fromPointer(ptr: topic.awsByteCursor)

var topicByteCursor = topic.awsByteCursor
var payloadByteCursor = payload.awsByteCursor
let packetId = aws_mqtt_client_connection_publish(rawValue,
payloadPointer,
&topicByteCursor,
qos.rawValue,
retain,
topicPointer,
&payloadByteCursor,
{ (_, packetId, errorCode, userData) in
guard let userData = userData else {
return
}
let ptr = userData.assumingMemoryBound(to: OpCompleteCallbackData.self)
defer { ptr.deinitializeAndDeallocate()}

let error = AWSError(errorCode: errorCode)
ptr.pointee.onOperationComplete(ptr.pointee.connection,
Int16(packetId),
CRTError.crtError(error))
}, opCallbackPtr)

topicPointer.deinitializeAndDeallocate()
payloadPointer.deinitializeAndDeallocate()

return packetId
}

Expand Down

0 comments on commit faff91d

Please sign in to comment.