diff --git a/Sources/NIOSSL/SSLConnection.swift b/Sources/NIOSSL/SSLConnection.swift index fca85a96..fa44b792 100644 --- a/Sources/NIOSSL/SSLConnection.swift +++ b/Sources/NIOSSL/SSLConnection.swift @@ -495,6 +495,18 @@ extension SSLConnection { return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) } } } + + func applyOverride(_ changes: NIOSSLContextConfigurationOverride) throws { + let connection = UnsafeKeyAndChainTarget.ssl(self.ssl) + if let chain = changes.certificateChain { + try connection.useCertificateChain(chain) + } + + // Attempt to load the new private key and abort on failure + if let pkey = changes.privateKey { + try connection.usePrivateKeySource(pkey) + } + } } extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection { diff --git a/Sources/NIOSSL/SSLContext.swift b/Sources/NIOSSL/SSLContext.swift index d8abc54f..5eef4011 100644 --- a/Sources/NIOSSL/SSLContext.swift +++ b/Sources/NIOSSL/SSLContext.swift @@ -262,14 +262,8 @@ private func sslContextCallback(ssl: OpaquePointer?, arg: UnsafeMutableRawPointe case .success(let changes): do { // Attempt to load the new certificate chain and abort on failure - if let chain = changes.certificateChain { - try NIOSSLContext.useCertificateChain(chain, context: parentSwiftContext.sslContext) - } - - // Attempt to load the new private key and abort on failure - if let pkey = changes.privateKey { - try NIOSSLContext.usePrivateKeySource(pkey, context: parentSwiftContext.sslContext) - } + let ssl = SSLConnection.loadConnectionFromSSL(ssl) + try ssl.applyOverride(changes) // We must return 1 to signal a successful load of the new context return 1 @@ -450,10 +444,11 @@ public final class NIOSSLContext { ) } - try NIOSSLContext.useCertificateChain(configuration.certificateChain, context: context) + let handle = UnsafeKeyAndChainTarget.sslContext(context) + try handle.useCertificateChain(configuration.certificateChain) if let pkey = configuration.privateKey { - try NIOSSLContext.usePrivateKeySource(pkey, context: context) + try handle.usePrivateKeySource(pkey) } if configuration.encodedApplicationProtocols.count > 0 { @@ -579,101 +574,6 @@ extension NIOSSLContext { } extension NIOSSLContext { - fileprivate static func useCertificateChain( - _ certificateChain: [NIOSSLCertificateSource], - context: OpaquePointer - ) throws { - var leaf = true - for source in certificateChain { - switch source { - case .file(let p): - NIOSSLContext.useCertificateChainFile(p, context: context) - leaf = false - case .certificate(let cert): - if leaf { - try NIOSSLContext.setLeafCertificate(cert, context: context) - leaf = false - } else { - try NIOSSLContext.addAdditionalChainCertificate(cert, context: context) - } - } - } - } - - private static func useCertificateChainFile(_ path: String, context: OpaquePointer) { - // TODO(cory): This shouldn't be an assert but should instead be actual error handling. - // assert(path.isFileURL) - let result = path.withCString { (pointer) -> CInt in - CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer) - } - - // TODO(cory): again, some error handling would be good. - precondition(result == 1) - } - - private static func setLeafCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws { - let rc = cert.withUnsafeMutableX509Pointer { ref in - CNIOBoringSSL_SSL_CTX_use_certificate(context, ref) - } - guard rc == 1 else { - throw NIOSSLError.failedToLoadCertificate - } - } - - private static func addAdditionalChainCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws { - let rc = cert.withUnsafeMutableX509Pointer { ref in - CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref) - } - guard rc == 1 else { - throw NIOSSLError.failedToLoadCertificate - } - } - - fileprivate static func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource, context: OpaquePointer) throws { - switch privateKey { - case .file(let p): - try NIOSSLContext.usePrivateKeyFile(p, context: context) - case .privateKey(let key): - try NIOSSLContext.setPrivateKey(key, context: context) - } - } - - private static func setPrivateKey(_ key: NIOSSLPrivateKey, context: OpaquePointer) throws { - switch key.representation { - case .native: - let rc = key.withUnsafeMutableEVPPKEYPointer { ref in - CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref) - } - guard 1 == rc else { - throw NIOSSLError.failedToLoadPrivateKey - } - case .custom: - CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod) - } - } - - private static func usePrivateKeyFile(_ path: String, context: OpaquePointer) throws { - let pathExtension = path.split(separator: ".").last - let fileType: CInt - - switch pathExtension?.lowercased() { - case .some("pem"): - fileType = SSL_FILETYPE_PEM - case .some("der"), .some("key"): - fileType = SSL_FILETYPE_ASN1 - default: - throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path) - } - - let result = path.withCString { (pointer) -> CInt in - CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType) - } - - guard result == 1 else { - throw NIOSSLError.failedToLoadPrivateKey - } - } - private static func setAlpnProtocols(_ protocols: [[UInt8]], context: OpaquePointer) throws { // This copy should be done infrequently, so we don't worry too much about it. let protoBuf = protocols.reduce([UInt8](), +) diff --git a/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift b/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift new file mode 100644 index 00000000..8f212e5f --- /dev/null +++ b/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift @@ -0,0 +1,140 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +@_implementationOnly import CNIOBoringSSL + +enum UnsafeKeyAndChainTarget { + case sslContext(OpaquePointer) + case ssl(OpaquePointer) + + func useCertificateChain( + _ certificateChain: [NIOSSLCertificateSource] + ) throws { + var leaf = true + for source in certificateChain { + switch source { + case .file(let p): + self.useCertificateChainFile(p) + leaf = false + case .certificate(let cert): + if leaf { + try self.setLeafCertificate(cert) + leaf = false + } else { + try self.addAdditionalChainCertificate(cert) + } + } + } + } + + func useCertificateChainFile(_ path: String) { + let result = path.withCString { (pointer) -> CInt in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer) + case .ssl(let ssl): + CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(ssl, pointer) + } + } + + precondition(result == 1) + } + + func setLeafCertificate(_ cert: NIOSSLCertificate) throws { + let rc = cert.withUnsafeMutableX509Pointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_certificate(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_certificate(ssl, ref) + } + } + guard rc == 1 else { + throw NIOSSLError.failedToLoadCertificate + } + } + + func addAdditionalChainCertificate(_ cert: NIOSSLCertificate) throws { + let rc = cert.withUnsafeMutableX509Pointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_add1_chain_cert(ssl, ref) + } + } + guard rc == 1 else { + throw NIOSSLError.failedToLoadCertificate + } + } + + func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource) throws { + switch privateKey { + case .file(let p): + try self.usePrivateKeyFile(p) + case .privateKey(let key): + try self.setPrivateKey(key) + } + } + + func setPrivateKey(_ key: NIOSSLPrivateKey) throws { + switch key.representation { + case .native: + let rc = key.withUnsafeMutableEVPPKEYPointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_PrivateKey(ssl, ref) + } + } + guard 1 == rc else { + throw NIOSSLError.failedToLoadPrivateKey + } + case .custom: + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod) + case .ssl(let ssl): + CNIOBoringSSL_SSL_set_private_key_method(ssl, customPrivateKeyMethod) + } + } + } + + func usePrivateKeyFile(_ path: String) throws { + let pathExtension = path.split(separator: ".").last + let fileType: CInt + + switch pathExtension?.lowercased() { + case .some("pem"): + fileType = SSL_FILETYPE_PEM + case .some("der"), .some("key"): + fileType = SSL_FILETYPE_ASN1 + default: + throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path) + } + + let result = path.withCString { (pointer) -> CInt in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_PrivateKey_file(ssl, pointer, fileType) + } + } + + guard result == 1 else { + throw NIOSSLError.failedToLoadPrivateKey + } + } +} \ No newline at end of file diff --git a/Tests/NIOSSLTests/TLSConfigurationTest.swift b/Tests/NIOSSLTests/TLSConfigurationTest.swift index 089ae4fd..f9a0f2b1 100644 --- a/Tests/NIOSSLTests/TLSConfigurationTest.swift +++ b/Tests/NIOSSLTests/TLSConfigurationTest.swift @@ -1907,6 +1907,107 @@ class TLSConfigurationTest: XCTestCase { try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) waitForExpectations(timeout: 1) } + + func testClientSideCertSelection() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + clientConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert2)] + override.privateKey = .privateKey(TLSConfigurationTest.key2) + promise.succeed(override) + } + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.certificateVerification = .noHostnameVerification + serverConfig.trustRoots = .certificates([TLSConfigurationTest.cert2]) + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testServerSideCertSelection() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert2)], + privateKey: .privateKey(TLSConfigurationTest.key2) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert1)] + override.privateKey = .privateKey(TLSConfigurationTest.key1) + promise.succeed(override) + } + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testOverrideWithNothingIsFine() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.sslContextCallback = { _, promise in + let `override` = NIOSSLContextConfigurationOverride() + promise.succeed(override) + } + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testOverrideToInvalidCertFailsHandshake() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert2)], + privateKey: .privateKey(TLSConfigurationTest.key2) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert1)] + promise.succeed(override) + } + + try assertHandshakeError( + withClientConfig: clientConfig, + andServerConfig: serverConfig, + errorTextContains: "TLSV1_ALERT_INTERNAL_ERROR" + ) + } + + func testOverrideToInvalidKeyFailsHandshake() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.privateKey = .privateKey(TLSConfigurationTest.key2) + promise.succeed(override) + } + + try assertHandshakeError( + withClientConfig: clientConfig, + andServerConfig: serverConfig, + errorTextContains: "TLSV1_ALERT_INTERNAL_ERROR" + ) + } } extension EmbeddedChannel {