Skip to content

Commit

Permalink
Correctly override connection state using sslContextCallback
Browse files Browse the repository at this point in the history
Motivation:

The sslContextCallback overrides modify the SSL_CTX object
itself. That's not useful: in addition to having a global effect,
which is rarely what we want, certain modifications don't apply
that late, and setting the credentials is one of them. The result
is that the callback didn't actually do anything. Not that useful.

Modifications:

- Extract the cert chain setting functions from NIOSSLContext and
    move them to general-purpose functions.
- Add support there for calling them on a SSL * as well as on a
    SSL_CTX *
- Call them in an override function on SSLConnection
- Call that!
- Add a bunch of tests.

Result:

The callback actually does what it should.
  • Loading branch information
Lukasa committed Dec 17, 2024
1 parent 9fc4828 commit 284b780
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 105 deletions.
12 changes: 12 additions & 0 deletions Sources/NIOSSL/SSLConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,18 @@ extension SSLConnection {
return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) }
}
}

func applyOverride(_ changes: NIOSSLContextConfigurationOverride) throws {
let connection = UnsafeKeyAndChainTarget.ssl(self.ssl)
if let chain = changes.certificateChain {
try connection.useCertificateChain(chain)
}

// Attempt to load the new private key and abort on failure
if let pkey = changes.privateKey {
try connection.usePrivateKeySource(pkey)
}
}
}

extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection {
Expand Down
110 changes: 5 additions & 105 deletions Sources/NIOSSL/SSLContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,8 @@ private func sslContextCallback(ssl: OpaquePointer?, arg: UnsafeMutableRawPointe
case .success(let changes):
do {
// Attempt to load the new certificate chain and abort on failure
if let chain = changes.certificateChain {
try NIOSSLContext.useCertificateChain(chain, context: parentSwiftContext.sslContext)
}

// Attempt to load the new private key and abort on failure
if let pkey = changes.privateKey {
try NIOSSLContext.usePrivateKeySource(pkey, context: parentSwiftContext.sslContext)
}
let ssl = SSLConnection.loadConnectionFromSSL(ssl)
try ssl.applyOverride(changes)

// We must return 1 to signal a successful load of the new context
return 1
Expand Down Expand Up @@ -450,10 +444,11 @@ public final class NIOSSLContext {
)
}

try NIOSSLContext.useCertificateChain(configuration.certificateChain, context: context)
let handle = UnsafeKeyAndChainTarget.sslContext(context)
try handle.useCertificateChain(configuration.certificateChain)

if let pkey = configuration.privateKey {
try NIOSSLContext.usePrivateKeySource(pkey, context: context)
try handle.usePrivateKeySource(pkey)
}

if configuration.encodedApplicationProtocols.count > 0 {
Expand Down Expand Up @@ -579,101 +574,6 @@ extension NIOSSLContext {
}

extension NIOSSLContext {
fileprivate static func useCertificateChain(
_ certificateChain: [NIOSSLCertificateSource],
context: OpaquePointer
) throws {
var leaf = true
for source in certificateChain {
switch source {
case .file(let p):
NIOSSLContext.useCertificateChainFile(p, context: context)
leaf = false
case .certificate(let cert):
if leaf {
try NIOSSLContext.setLeafCertificate(cert, context: context)
leaf = false
} else {
try NIOSSLContext.addAdditionalChainCertificate(cert, context: context)
}
}
}
}

private static func useCertificateChainFile(_ path: String, context: OpaquePointer) {
// TODO(cory): This shouldn't be an assert but should instead be actual error handling.
// assert(path.isFileURL)
let result = path.withCString { (pointer) -> CInt in
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer)
}

// TODO(cory): again, some error handling would be good.
precondition(result == 1)
}

private static func setLeafCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws {
let rc = cert.withUnsafeMutableX509Pointer { ref in
CNIOBoringSSL_SSL_CTX_use_certificate(context, ref)
}
guard rc == 1 else {
throw NIOSSLError.failedToLoadCertificate
}
}

private static func addAdditionalChainCertificate(_ cert: NIOSSLCertificate, context: OpaquePointer) throws {
let rc = cert.withUnsafeMutableX509Pointer { ref in
CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref)
}
guard rc == 1 else {
throw NIOSSLError.failedToLoadCertificate
}
}

fileprivate static func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource, context: OpaquePointer) throws {
switch privateKey {
case .file(let p):
try NIOSSLContext.usePrivateKeyFile(p, context: context)
case .privateKey(let key):
try NIOSSLContext.setPrivateKey(key, context: context)
}
}

private static func setPrivateKey(_ key: NIOSSLPrivateKey, context: OpaquePointer) throws {
switch key.representation {
case .native:
let rc = key.withUnsafeMutableEVPPKEYPointer { ref in
CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref)
}
guard 1 == rc else {
throw NIOSSLError.failedToLoadPrivateKey
}
case .custom:
CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod)
}
}

private static func usePrivateKeyFile(_ path: String, context: OpaquePointer) throws {
let pathExtension = path.split(separator: ".").last
let fileType: CInt

switch pathExtension?.lowercased() {
case .some("pem"):
fileType = SSL_FILETYPE_PEM
case .some("der"), .some("key"):
fileType = SSL_FILETYPE_ASN1
default:
throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path)
}

let result = path.withCString { (pointer) -> CInt in
CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType)
}

guard result == 1 else {
throw NIOSSLError.failedToLoadPrivateKey
}
}

private static func setAlpnProtocols(_ protocols: [[UInt8]], context: OpaquePointer) throws {
// This copy should be done infrequently, so we don't worry too much about it.
let protoBuf = protocols.reduce([UInt8](), +)
Expand Down
140 changes: 140 additions & 0 deletions Sources/NIOSSL/UnsafeKeyAndChainTarget.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
@_implementationOnly import CNIOBoringSSL

enum UnsafeKeyAndChainTarget {
case sslContext(OpaquePointer)
case ssl(OpaquePointer)

func useCertificateChain(
_ certificateChain: [NIOSSLCertificateSource]
) throws {
var leaf = true
for source in certificateChain {
switch source {
case .file(let p):
self.useCertificateChainFile(p)
leaf = false
case .certificate(let cert):
if leaf {
try self.setLeafCertificate(cert)
leaf = false
} else {
try self.addAdditionalChainCertificate(cert)
}
}
}
}

func useCertificateChainFile(_ path: String) {
let result = path.withCString { (pointer) -> CInt in
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(context, pointer)
case .ssl(let ssl):
CNIOBoringSSL_SSL_CTX_use_certificate_chain_file(ssl, pointer)
}
}

precondition(result == 1)
}

func setLeafCertificate(_ cert: NIOSSLCertificate) throws {
let rc = cert.withUnsafeMutableX509Pointer { ref in
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_use_certificate(context, ref)
case .ssl(let ssl):
CNIOBoringSSL_SSL_use_certificate(ssl, ref)
}
}
guard rc == 1 else {
throw NIOSSLError.failedToLoadCertificate
}
}

func addAdditionalChainCertificate(_ cert: NIOSSLCertificate) throws {
let rc = cert.withUnsafeMutableX509Pointer { ref in
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_add1_chain_cert(context, ref)
case .ssl(let ssl):
CNIOBoringSSL_SSL_add1_chain_cert(ssl, ref)
}
}
guard rc == 1 else {
throw NIOSSLError.failedToLoadCertificate
}
}

func usePrivateKeySource(_ privateKey: NIOSSLPrivateKeySource) throws {
switch privateKey {
case .file(let p):
try self.usePrivateKeyFile(p)
case .privateKey(let key):
try self.setPrivateKey(key)
}
}

func setPrivateKey(_ key: NIOSSLPrivateKey) throws {
switch key.representation {
case .native:
let rc = key.withUnsafeMutableEVPPKEYPointer { ref in
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_use_PrivateKey(context, ref)
case .ssl(let ssl):
CNIOBoringSSL_SSL_use_PrivateKey(ssl, ref)
}
}
guard 1 == rc else {
throw NIOSSLError.failedToLoadPrivateKey
}
case .custom:
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_set_private_key_method(context, customPrivateKeyMethod)
case .ssl(let ssl):
CNIOBoringSSL_SSL_set_private_key_method(ssl, customPrivateKeyMethod)
}
}
}

func usePrivateKeyFile(_ path: String) throws {
let pathExtension = path.split(separator: ".").last
let fileType: CInt

switch pathExtension?.lowercased() {
case .some("pem"):
fileType = SSL_FILETYPE_PEM
case .some("der"), .some("key"):
fileType = SSL_FILETYPE_ASN1
default:
throw NIOSSLExtraError.unknownPrivateKeyFileType(path: path)
}

let result = path.withCString { (pointer) -> CInt in
switch self {
case .sslContext(let context):
CNIOBoringSSL_SSL_CTX_use_PrivateKey_file(context, pointer, fileType)
case .ssl(let ssl):
CNIOBoringSSL_SSL_use_PrivateKey_file(ssl, pointer, fileType)
}
}

guard result == 1 else {
throw NIOSSLError.failedToLoadPrivateKey
}
}
}
Loading

0 comments on commit 284b780

Please sign in to comment.