Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions Sources/Engine/NativeEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import Foundation
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionWebSocketDelegate {
private var task: URLSessionWebSocketTask?
private var clientCredential: URLCredential?
weak var delegate: EngineDelegate?

public init(clientCredential: URLCredential? = nil) {
self.clientCredential = clientCredential
}

public func register(delegate: EngineDelegate) {
self.delegate = delegate
}
Expand Down Expand Up @@ -93,4 +98,17 @@ public class NativeEngine: NSObject, Engine, URLSessionDataDelegate, URLSessionW
}
broadcast(event: .disconnected(r, UInt16(closeCode.rawValue)))
}

public func urlSession(_ session: URLSession, didReceive challenge: URLAuthenticationChallenge, completionHandler: @escaping (URLSession.AuthChallengeDisposition, URLCredential?) -> Void) {
var credential: URLCredential? = nil
var disposition: URLSession.AuthChallengeDisposition = .performDefaultHandling

let authMethod = challenge.protectionSpace.authenticationMethod
if authMethod == NSURLAuthenticationMethodClientCertificate && self.clientCredential != nil {
credential = self.clientCredential
disposition = .useCredential
}

completionHandler(disposition, credential)
}
}
5 changes: 4 additions & 1 deletion Sources/Engine/WSEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
private let httpHandler: HTTPHandler
private let compressionHandler: CompressionHandler?
private let certPinner: CertificatePinning?
private let clientCredential: URLCredential?
private let headerChecker: HeaderValidator
private var request: URLRequest!

Expand All @@ -30,6 +31,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {

public init(transport: Transport,
certPinner: CertificatePinning? = nil,
clientCredential: URLCredential? = nil,
headerValidator: HeaderValidator = FoundationSecurity(),
httpHandler: HTTPHandler = FoundationHTTPHandler(),
framer: Framer = WSFramer(),
Expand All @@ -38,6 +40,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
self.framer = framer
self.httpHandler = httpHandler
self.certPinner = certPinner
self.clientCredential = clientCredential
self.headerChecker = headerValidator
self.compressionHandler = compressionHandler
framer.updateCompression(supports: compressionHandler != nil)
Expand All @@ -64,7 +67,7 @@ FrameCollectorDelegate, HTTPHandlerDelegate {
guard let url = request.url else {
return
}
transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner)
transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner, clientCredential: clientCredential)
}

public func stop(closeCode: UInt16 = CloseCode.normal.rawValue) {
Expand Down
3 changes: 2 additions & 1 deletion Sources/Framer/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ public struct HTTPWSHeader {
let val = "permessage-deflate; client_max_window_bits; server_max_window_bits=15"
req.setValue(val, forHTTPHeaderField: HTTPWSHeader.extensionName)
}
let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? "\(parts.host):\(parts.port)"
let hostname = request.url?.port != nil ? "\(parts.host):\(parts.port)" : parts.host
let hostValue = req.allHTTPHeaderFields?[HTTPWSHeader.hostName] ?? hostname
req.setValue(hostValue, forHTTPHeaderField: HTTPWSHeader.hostName)
return req
}
Expand Down
8 changes: 4 additions & 4 deletions Sources/Starscream/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ open class WebSocket: WebSocketClient, EngineDelegate {
self.engine = engine
}

public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) {
public convenience init(request: URLRequest, certPinner: CertificatePinning? = FoundationSecurity(), clientCredential: URLCredential? = nil, compressionHandler: CompressionHandler? = nil, useCustomEngine: Bool = true) {
if #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *), !useCustomEngine {
self.init(request: request, engine: NativeEngine())
self.init(request: request, engine: NativeEngine(clientCredential: clientCredential))
} else if #available(macOS 10.14, iOS 12.0, watchOS 5.0, tvOS 12.0, *) {
self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, compressionHandler: compressionHandler))
self.init(request: request, engine: WSEngine(transport: TCPTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler))
} else {
self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, compressionHandler: compressionHandler))
self.init(request: request, engine: WSEngine(transport: FoundationTransport(), certPinner: certPinner, clientCredential: clientCredential, compressionHandler: compressionHandler))
}
}

Expand Down
10 changes: 9 additions & 1 deletion Sources/Transport/FoundationTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate {
outputStream?.delegate = nil
}

public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) {
public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) {
guard let parts = url.getParts() else {
delegate?.connectionChanged(state: .failed(FoundationTransportError.invalidRequest))
return
Expand All @@ -75,6 +75,14 @@ public class FoundationTransport: NSObject, Transport, StreamDelegate {
let key = CFStreamPropertyKey(rawValue: kCFStreamPropertySocketSecurityLevel)
CFReadStreamSetProperty(inStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL)
CFWriteStreamSetProperty(outStream, key, kCFStreamSocketSecurityLevelNegotiatedSSL)

if let clientCredential = clientCredential {
let certificates = [clientCredential.identity] + clientCredential.certificates
let sslSettings = [kCFStreamSSLCertificates: certificates] as CFDictionary
let sslSettingsKey = CFStreamPropertyKey(rawValue: kCFStreamPropertySSLSettings)
CFReadStreamSetProperty(inStream, sslSettingsKey, sslSettings)
CFWriteStreamSetProperty(outStream, sslSettingsKey, sslSettings)
}
}

onConnect?(inStream, outStream)
Expand Down
8 changes: 7 additions & 1 deletion Sources/Transport/TCPTransport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class TCPTransport: Transport {
//normal connection, will use the "connect" method below
}

public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil) {
public func connect(url: URL, timeout: Double = 10, certificatePinning: CertificatePinning? = nil, clientCredential: URLCredential? = nil) {
guard let parts = url.getParts() else {
delegate?.connectionChanged(state: .failed(TCPTransportError.invalidRequest))
return
Expand All @@ -75,6 +75,12 @@ public class TCPTransport: Transport {
}
})
}, queue)

if let clientCredential = clientCredential {
sec_protocol_options_set_challenge_block(tlsOpts.securityProtocolOptions, { (_, completionHandler) in
completionHandler(sec_identity_create(clientCredential.identity!)!)
}, queue)
}
}
let parameters = NWParameters(tls: tlsOptions, tcp: options)
let conn = NWConnection(host: NWEndpoint.Host.name(parts.host, nil), port: NWEndpoint.Port(rawValue: UInt16(parts.port))!, using: parameters)
Expand Down
2 changes: 1 addition & 1 deletion Sources/Transport/Transport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public protocol TransportEventClient: class {

public protocol Transport: class {
func register(delegate: TransportEventClient)
func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?)
func connect(url: URL, timeout: Double, certificatePinning: CertificatePinning?, clientCredential: URLCredential?)
func disconnect()
func write(data: Data, completion: @escaping ((Error?) -> ()))
var usingTLS: Bool { get }
Expand Down