From 284b7804190a46de7c28b834c0be752afa1f997f Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Tue, 17 Dec 2024 20:40:10 +0000 Subject: [PATCH] Correctly override connection state using sslContextCallback Motivation: The sslContextCallback overrides modify the SSL_CTX object itself. That's not useful: in addition to having a global effect, which is rarely what we want, certain modifications don't apply that late, and setting the credentials is one of them. The result is that the callback didn't actually do anything. Not that useful. Modifications: - Extract the cert chain setting functions from NIOSSLContext and move them to general-purpose functions. - Add support there for calling them on a SSL * as well as on a SSL_CTX * - Call them in an override function on SSLConnection - Call that! - Add a bunch of tests. Result: The callback actually does what it should. --- Sources/NIOSSL/SSLConnection.swift | 12 ++ Sources/NIOSSL/SSLContext.swift | 110 +-------------- Sources/NIOSSL/UnsafeKeyAndChainTarget.swift | 140 +++++++++++++++++++ Tests/NIOSSLTests/TLSConfigurationTest.swift | 101 +++++++++++++ 4 files changed, 258 insertions(+), 105 deletions(-) create mode 100644 Sources/NIOSSL/UnsafeKeyAndChainTarget.swift diff --git a/Sources/NIOSSL/SSLConnection.swift b/Sources/NIOSSL/SSLConnection.swift index fca85a96..fa44b792 100644 --- a/Sources/NIOSSL/SSLConnection.swift +++ b/Sources/NIOSSL/SSLConnection.swift @@ -495,6 +495,18 @@ extension SSLConnection { return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) } } } + + func applyOverride(_ changes: NIOSSLContextConfigurationOverride) throws { + let connection = UnsafeKeyAndChainTarget.ssl(self.ssl) + if let chain = changes.certificateChain { + try connection.useCertificateChain(chain) + } + + // Attempt to load the new private key and abort on failure + if let pkey = changes.privateKey { + try connection.usePrivateKeySource(pkey) + } + } } extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection { diff --git a/Sources/NIOSSL/SSLContext.swift b/Sources/NIOSSL/SSLContext.swift index d8abc54f..5eef4011 100644 --- a/Sources/NIOSSL/SSLContext.swift +++ b/Sources/NIOSSL/SSLContext.swift @@ -262,14 +262,8 @@ private func sslContextCallback(ssl: OpaquePointer?, arg: UnsafeMutableRawPointe case .success(let changes): do { // Attempt to load the new certificate chain and abort on failure - if let chain = changes.certificateChain { - try NIOSSLContext.useCertificateChain(chain, context: parentSwiftContext.sslContext) - } - - // Attempt to load the new private key and abort on failure - if let pkey = changes.privateKey { - try NIOSSLContext.usePrivateKeySource(pkey, context: parentSwiftContext.sslContext) - } + let ssl = SSLConnection.loadConnectionFromSSL(ssl) + try ssl.applyOverride(changes) // We must return 1 to signal a successful load of the new context return 1 @@ -450,10 +444,11 @@ public final class NIOSSLContext { ) } - try NIOSSLContext.useCertificateChain(configuration.certificateChain, context: context) + let handle = UnsafeKeyAndChainTarget.sslContext(context) + try handle.useCertificateChain(configuration.certificateChain) if let pkey = configuration.privateKey { - try NIOSSLContext.usePrivateKeySource(pkey, context: context) + try handle.usePrivateKeySource(pkey) } if configuration.encodedApplicationProtocols.count > 0 { @@ -579,101 +574,6 @@ extension NIOSSLContext { } extension NIOSSLContext { - fileprivate static func useCertificateChain( - _ certificateChain: [NIOSSLCertificateSource], - context: OpaquePointer - ) throws { - var leaf = true - for source in certificateChain { - switch source { - case .file(let p): - NIOSSLContext.useCertificateChainFile(p, context: context) - leaf = false - case .certificate(let cert): - if leaf { - try NIOSSLContext.setLeafCertificate(cert, context: context) - leaf = false - } else { - try NIOSSLContext.addAdditionalChainCertificate(cert, context: context) - } - } - } - } - - private static func useCertificateChainFile(_ path: String, context: OpaquePointer) { - // TODO(cory): This shouldn't be an assert but should instead be actual error handling. - // assert(path.isFileURL) - let result = path.withCString { (pointer) -> CInt in - CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer) - } - - // TODO(cory): again, some error handling would be good. - precondition(result == 1) - } - - private static func setLeafCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws { - let rc = cert.withUnsafeMutableX509Pointer { ref in - CNIOBoringSSL_SSL_CTX_use_certificate(context, ref) - } - guard rc == 1 else { - throw NIOSSLError.failedToLoadCertificate - } - } - - private static func addAdditionalChainCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws { - let rc = cert.withUnsafeMutableX509Pointer { ref in - CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref) - } - guard rc == 1 else { - throw NIOSSLError.failedToLoadCertificate - } - } - - fileprivate static func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource, context: OpaquePointer) throws { - switch privateKey { - case .file(let p): - try NIOSSLContext.usePrivateKeyFile(p, context: context) - case .privateKey(let key): - try NIOSSLContext.setPrivateKey(key, context: context) - } - } - - private static func setPrivateKey(_ key: NIOSSLPrivateKey, context: OpaquePointer) throws { - switch key.representation { - case .native: - let rc = key.withUnsafeMutableEVPPKEYPointer { ref in - CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref) - } - guard 1 == rc else { - throw NIOSSLError.failedToLoadPrivateKey - } - case .custom: - CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod) - } - } - - private static func usePrivateKeyFile(_ path: String, context: OpaquePointer) throws { - let pathExtension = path.split(separator: ".").last - let fileType: CInt - - switch pathExtension?.lowercased() { - case .some("pem"): - fileType = SSL_FILETYPE_PEM - case .some("der"), .some("key"): - fileType = SSL_FILETYPE_ASN1 - default: - throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path) - } - - let result = path.withCString { (pointer) -> CInt in - CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType) - } - - guard result == 1 else { - throw NIOSSLError.failedToLoadPrivateKey - } - } - private static func setAlpnProtocols(_ protocols: [[UInt8]], context: OpaquePointer) throws { // This copy should be done infrequently, so we don't worry too much about it. let protoBuf = protocols.reduce([UInt8](), +) diff --git a/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift b/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift new file mode 100644 index 00000000..8f212e5f --- /dev/null +++ b/Sources/NIOSSL/UnsafeKeyAndChainTarget.swift @@ -0,0 +1,140 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +@_implementationOnly import CNIOBoringSSL + +enum UnsafeKeyAndChainTarget { + case sslContext(OpaquePointer) + case ssl(OpaquePointer) + + func useCertificateChain( + _ certificateChain: [NIOSSLCertificateSource] + ) throws { + var leaf = true + for source in certificateChain { + switch source { + case .file(let p): + self.useCertificateChainFile(p) + leaf = false + case .certificate(let cert): + if leaf { + try self.setLeafCertificate(cert) + leaf = false + } else { + try self.addAdditionalChainCertificate(cert) + } + } + } + } + + func useCertificateChainFile(_ path: String) { + let result = path.withCString { (pointer) -> CInt in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer) + case .ssl(let ssl): + CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(ssl, pointer) + } + } + + precondition(result == 1) + } + + func setLeafCertificate(_ cert: NIOSSLCertificate) throws { + let rc = cert.withUnsafeMutableX509Pointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_certificate(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_certificate(ssl, ref) + } + } + guard rc == 1 else { + throw NIOSSLError.failedToLoadCertificate + } + } + + func addAdditionalChainCertificate(_ cert: NIOSSLCertificate) throws { + let rc = cert.withUnsafeMutableX509Pointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_add1_chain_cert(ssl, ref) + } + } + guard rc == 1 else { + throw NIOSSLError.failedToLoadCertificate + } + } + + func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource) throws { + switch privateKey { + case .file(let p): + try self.usePrivateKeyFile(p) + case .privateKey(let key): + try self.setPrivateKey(key) + } + } + + func setPrivateKey(_ key: NIOSSLPrivateKey) throws { + switch key.representation { + case .native: + let rc = key.withUnsafeMutableEVPPKEYPointer { ref in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_PrivateKey(ssl, ref) + } + } + guard 1 == rc else { + throw NIOSSLError.failedToLoadPrivateKey + } + case .custom: + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod) + case .ssl(let ssl): + CNIOBoringSSL_SSL_set_private_key_method(ssl, customPrivateKeyMethod) + } + } + } + + func usePrivateKeyFile(_ path: String) throws { + let pathExtension = path.split(separator: ".").last + let fileType: CInt + + switch pathExtension?.lowercased() { + case .some("pem"): + fileType = SSL_FILETYPE_PEM + case .some("der"), .some("key"): + fileType = SSL_FILETYPE_ASN1 + default: + throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path) + } + + let result = path.withCString { (pointer) -> CInt in + switch self { + case .sslContext(let context): + CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType) + case .ssl(let ssl): + CNIOBoringSSL_SSL_use_PrivateKey_file(ssl, pointer, fileType) + } + } + + guard result == 1 else { + throw NIOSSLError.failedToLoadPrivateKey + } + } +} \ No newline at end of file diff --git a/Tests/NIOSSLTests/TLSConfigurationTest.swift b/Tests/NIOSSLTests/TLSConfigurationTest.swift index 089ae4fd..f9a0f2b1 100644 --- a/Tests/NIOSSLTests/TLSConfigurationTest.swift +++ b/Tests/NIOSSLTests/TLSConfigurationTest.swift @@ -1907,6 +1907,107 @@ class TLSConfigurationTest: XCTestCase { try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) waitForExpectations(timeout: 1) } + + func testClientSideCertSelection() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + clientConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert2)] + override.privateKey = .privateKey(TLSConfigurationTest.key2) + promise.succeed(override) + } + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.certificateVerification = .noHostnameVerification + serverConfig.trustRoots = .certificates([TLSConfigurationTest.cert2]) + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testServerSideCertSelection() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert2)], + privateKey: .privateKey(TLSConfigurationTest.key2) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert1)] + override.privateKey = .privateKey(TLSConfigurationTest.key1) + promise.succeed(override) + } + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testOverrideWithNothingIsFine() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.sslContextCallback = { _, promise in + let `override` = NIOSSLContextConfigurationOverride() + promise.succeed(override) + } + + try assertHandshakeSucceeded(withClientConfig: clientConfig, andServerConfig: serverConfig) + } + + func testOverrideToInvalidCertFailsHandshake() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert2)], + privateKey: .privateKey(TLSConfigurationTest.key2) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.certificateChain = [.certificate(TLSConfigurationTest.cert1)] + promise.succeed(override) + } + + try assertHandshakeError( + withClientConfig: clientConfig, + andServerConfig: serverConfig, + errorTextContains: "TLSV1_ALERT_INTERNAL_ERROR" + ) + } + + func testOverrideToInvalidKeyFailsHandshake() throws { + var clientConfig = TLSConfiguration.makeClientConfiguration() + clientConfig.certificateVerification = .noHostnameVerification + clientConfig.trustRoots = .certificates([TLSConfigurationTest.cert1]) + + var serverConfig = TLSConfiguration.makeServerConfiguration( + certificateChain: [.certificate(TLSConfigurationTest.cert1)], + privateKey: .privateKey(TLSConfigurationTest.key1) + ) + serverConfig.sslContextCallback = { _, promise in + var `override` = NIOSSLContextConfigurationOverride() + override.privateKey = .privateKey(TLSConfigurationTest.key2) + promise.succeed(override) + } + + try assertHandshakeError( + withClientConfig: clientConfig, + andServerConfig: serverConfig, + errorTextContains: "TLSV1_ALERT_INTERNAL_ERROR" + ) + } } extension EmbeddedChannel {