From 56af1078adf3e179868a0e893d89c2ea2ca04836 Mon Sep 17 00:00:00 2001 From: Radek Zikmund <32671551+rzikm@users.noreply.github.com> Date: Wed, 28 Feb 2024 15:51:42 +0100 Subject: [PATCH] Don't call user callbacks on MsQuic worker thread. (#98361) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow switching execution profiles using env vars * Quick and dirty version to enable benchmarking * Don't call callbacks from MsQuic threads * Remove unintentional changes * Offload parsing to threadpool as well * Customize TLS ALERT code * Code review feedback * Apply suggestions from code review Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> * use ConfigureAwaitOptions.ForceYielding * Version check to work around https://github.com/microsoft/msquic/issues/4132 * Use configure await to yield to threadpool * Fix functionality on older msquic versions --------- Co-authored-by: Marie Píchová <11718369+ManickaP@users.noreply.github.com> --- .../Net/Security/CertificateValidation.OSX.cs | 27 ++- .../Security/CertificateValidation.Unix.cs | 2 +- .../Security/CertificateValidation.Windows.cs | 2 +- .../Quic/Internal/MsQuicApi.NativeMethods.cs | 51 ++++++ .../src/System/Net/Quic/Internal/MsQuicApi.cs | 14 +- .../QuicConnection.SslConnectionOptions.cs | 156 ++++++++++++++---- .../src/System/Net/Quic/QuicConnection.cs | 19 ++- .../src/System/Net/Quic/QuicListener.cs | 5 + 8 files changed, 220 insertions(+), 56 deletions(-) diff --git a/src/libraries/Common/src/System/Net/Security/CertificateValidation.OSX.cs b/src/libraries/Common/src/System/Net/Security/CertificateValidation.OSX.cs index aee4b77b508340..b269a0fb70fa5b 100644 --- a/src/libraries/Common/src/System/Net/Security/CertificateValidation.OSX.cs +++ b/src/libraries/Common/src/System/Net/Security/CertificateValidation.OSX.cs @@ -14,7 +14,7 @@ internal static class CertificateValidation private static readonly IdnMapping s_idnMapping = new IdnMapping(); // WARNING: This function will do the verification using OpenSSL. If the intention is to use OS function, caller should use CertificatePal interface. - internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, IntPtr certificateBuffer, int bufferLength = 0) + internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, Span certificateBuffer) { SslPolicyErrors errors = chain.Build(remoteCertificate) ? SslPolicyErrors.None : @@ -31,15 +31,24 @@ internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X } SafeX509Handle certHandle; - if (certificateBuffer != IntPtr.Zero && bufferLength > 0) + unsafe { - certHandle = Interop.Crypto.DecodeX509(certificateBuffer, bufferLength); - } - else - { - // We dont't have DER encoded buffer. - byte[] der = remoteCertificate.Export(X509ContentType.Cert); - certHandle = Interop.Crypto.DecodeX509(Marshal.UnsafeAddrOfPinnedArrayElement(der, 0), der.Length); + if (certificateBuffer.Length > 0) + { + fixed (byte* pCert = certificateBuffer) + { + certHandle = Interop.Crypto.DecodeX509((IntPtr)pCert, certificateBuffer.Length); + } + } + else + { + // We dont't have DER encoded buffer. + byte[] der = remoteCertificate.Export(X509ContentType.Cert); + fixed (byte* pDer = der) + { + certHandle = Interop.Crypto.DecodeX509((IntPtr)pDer, der.Length); + } + } } int hostNameMatch; diff --git a/src/libraries/Common/src/System/Net/Security/CertificateValidation.Unix.cs b/src/libraries/Common/src/System/Net/Security/CertificateValidation.Unix.cs index 65a1adb492fa24..da3cb38a868227 100644 --- a/src/libraries/Common/src/System/Net/Security/CertificateValidation.Unix.cs +++ b/src/libraries/Common/src/System/Net/Security/CertificateValidation.Unix.cs @@ -13,7 +13,7 @@ internal static class CertificateValidation private static readonly IdnMapping s_idnMapping = new IdnMapping(); #pragma warning disable IDE0060 - internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength) + internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span certificateBuffer) => BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName); #pragma warning restore IDE0060 diff --git a/src/libraries/Common/src/System/Net/Security/CertificateValidation.Windows.cs b/src/libraries/Common/src/System/Net/Security/CertificateValidation.Windows.cs index d068015e534c49..90be80c734cc72 100644 --- a/src/libraries/Common/src/System/Net/Security/CertificateValidation.Windows.cs +++ b/src/libraries/Common/src/System/Net/Security/CertificateValidation.Windows.cs @@ -14,7 +14,7 @@ namespace System.Net internal static partial class CertificateValidation { #pragma warning disable IDE0060 - internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength) + internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span certificateBuffer) => BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName); #pragma warning restore IDE0060 diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs index 206eac76ac7878..6906392f79ebee 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.NativeMethods.cs @@ -375,4 +375,55 @@ public int StreamReceiveSetEnabled(MsQuicSafeHandle stream, byte enabled) } } } + + public int DatagramSend(MsQuicSafeHandle connection, QUIC_BUFFER* buffers, uint buffersCount, QUIC_SEND_FLAGS flags, void* context) + { + bool success = false; + try + { + connection.DangerousAddRef(ref success); + return ApiTable->DatagramSend(connection.QuicHandle, buffers, buffersCount, flags, context); + } + finally + { + if (success) + { + connection.DangerousRelease(); + } + } + } + + public int ConnectionResumptionTicketValidationComplete(MsQuicSafeHandle connection, byte result) + { + bool success = false; + try + { + connection.DangerousAddRef(ref success); + return ApiTable->ConnectionResumptionTicketValidationComplete(connection.QuicHandle, result); + } + finally + { + if (success) + { + connection.DangerousRelease(); + } + } + } + + public int ConnectionCertificateValidationComplete(MsQuicSafeHandle connection, byte result, QUIC_TLS_ALERT_CODES alert) + { + bool success = false; + try + { + connection.DangerousAddRef(ref success); + return ApiTable->ConnectionCertificateValidationComplete(connection.QuicHandle, result, alert); + } + finally + { + if (success) + { + connection.DangerousRelease(); + } + } + } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs index e89119844c744c..4b284284f5262c 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicApi.cs @@ -54,11 +54,16 @@ private MsQuicApi(QUIC_API_TABLE* apiTable) private static readonly Lazy _api = new Lazy(AllocateMsQuicApi); internal static MsQuicApi Api => _api.Value; + internal static Version? Version { get; private set; } + internal static bool IsQuicSupported { get; } internal static string MsQuicLibraryVersion { get; } = "unknown"; internal static string? NotSupportedReason { get; } + // workaround for https://github.com/microsoft/msquic/issues/4132 + internal static bool SupportsAsyncCertValidation => Version >= new Version(2, 4, 0); + internal static bool UsesSChannelBackend { get; } internal static bool Tls13ServerMayBeDisabled { get; } @@ -69,6 +74,7 @@ static MsQuicApi() { bool loaded = false; IntPtr msQuicHandle; + Version = default; // MsQuic is using DualMode sockets and that will fail even for IPv4 if AF_INET6 is not available. if (!Socket.OSSupportsIPv6) @@ -135,7 +141,7 @@ static MsQuicApi() } return; } - Version version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]); + Version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]); paramSize = 64 * sizeof(sbyte); sbyte* libGitHash = stackalloc sbyte[64]; @@ -150,11 +156,11 @@ static MsQuicApi() } string? gitHash = Marshal.PtrToStringUTF8((IntPtr)libGitHash); - MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {version} ({gitHash})"; + MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {Version} ({gitHash})"; - if (version < s_minMsQuicVersion) + if (Version < s_minMsQuicVersion) { - NotSupportedReason = $"Incompatible MsQuic library version '{version}', expecting higher than '{s_minMsQuicVersion}'."; + NotSupportedReason = $"Incompatible MsQuic library version '{Version}', expecting higher than '{s_minMsQuicVersion}'."; if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(null, NotSupportedReason); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs index dad23bfc342c8a..1b352f10045404 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.SslConnectionOptions.cs @@ -1,10 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Diagnostics; using System.Net.Security; using System.Security.Authentication; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; using Microsoft.Quic; using static Microsoft.Quic.MsQuic; @@ -63,18 +66,122 @@ public SslConnectionOptions(QuicConnection connection, bool isClient, _certificateChainPolicy = certificateChainPolicy; } - public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* chainPtr, out X509Certificate2? certificate) + internal async Task StartAsyncCertificateValidation(IntPtr certificatePtr, IntPtr chainPtr) + { + // + // The provided data pointers are valid only while still inside this function, so they need to be + // copied to separate buffers which are then handed off to threadpool. + // + + X509Certificate2? certificate = null; + + byte[]? certDataRented = null; + Memory certData = default; + byte[]? chainDataRented = null; + Memory chainData = default; + + if (certificatePtr != IntPtr.Zero) + { + if (MsQuicApi.UsesSChannelBackend) + { + // provided data is a pointer to a CERT_CONTEXT + certificate = new X509Certificate2(certificatePtr); + // TODO: what about chainPtr? + } + else + { + unsafe + { + // On non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the contents are buffers + // with DER encoded cert and chain. + QUIC_BUFFER* certificateBuffer = (QUIC_BUFFER*)certificatePtr; + QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainPtr; + + if (certificateBuffer->Length > 0) + { + certDataRented = ArrayPool.Shared.Rent((int)certificateBuffer->Length); + certData = certDataRented.AsMemory(0, (int)certificateBuffer->Length); + certificateBuffer->Span.CopyTo(certData.Span); + } + + if (chainBuffer->Length > 0) + { + chainDataRented = ArrayPool.Shared.Rent((int)chainBuffer->Length); + chainData = chainDataRented.AsMemory(0, (int)chainBuffer->Length); + chainBuffer->Span.CopyTo(chainData.Span); + } + } + } + } + + // We wan't to do the certificate validation asynchronously, but due to a bug in MsQuic, we need to call the callback synchronously on some versions + if (MsQuicApi.SupportsAsyncCertValidation) + { + // force yield to the thread pool to free up MsQuic worker thread. + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + } + + // certificatePtr and chainPtr are invalid beyond this point + + QUIC_TLS_ALERT_CODES result; + try + { + if (certData.Length > 0) + { + Debug.Assert(certificate == null); + certificate = new X509Certificate2(certData.Span); + } + + result = _connection._sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span); + _connection._remoteCertificate = certificate; + } + catch (Exception ex) + { + certificate?.Dispose(); + _connection._connectedTcs.TrySetException(ex); + result = QUIC_TLS_ALERT_CODES.USER_CANCELED; + } + finally + { + if (certDataRented != null) + { + ArrayPool.Shared.Return(certDataRented); + } + + if (chainDataRented != null) + { + ArrayPool.Shared.Return(chainDataRented); + } + } + + if (MsQuicApi.SupportsAsyncCertValidation) + { + int status = MsQuicApi.Api.ConnectionCertificateValidationComplete( + _connection._handle, + result == QUIC_TLS_ALERT_CODES.SUCCESS ? (byte)1 : (byte)0, + result); + + if (MsQuic.StatusFailed(status)) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(_connection, $"{_connection} ConnectionCertificateValidationComplete failed with {ThrowHelper.GetErrorMessageForStatus(status)}"); + } + } + } + + return result == QUIC_TLS_ALERT_CODES.SUCCESS; + } + + private QUIC_TLS_ALERT_CODES ValidateCertificate(X509Certificate2? certificate, Span certData, Span chainData) { SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None; - IntPtr certificateBuffer = 0; - int certificateLength = 0; bool wrapException = false; X509Chain? chain = null; - X509Certificate2? result = null; try { - if (certificatePtr is not null) + if (certificate is not null) { chain = new X509Chain(); if (_certificateChainPolicy != null) @@ -96,43 +203,26 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid); } - if (MsQuicApi.UsesSChannelBackend) + if (chainData.Length > 0) { - result = new X509Certificate2((IntPtr)certificatePtr); + X509Certificate2Collection additionalCertificates = new X509Certificate2Collection(); + additionalCertificates.Import(chainData); + chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates); } - else - { - if (certificatePtr->Length > 0) - { - certificateBuffer = (IntPtr)certificatePtr->Buffer; - certificateLength = (int)certificatePtr->Length; - result = new X509Certificate2(certificatePtr->Span); - } - if (chainPtr->Length > 0) - { - X509Certificate2Collection additionalCertificates = new X509Certificate2Collection(); - additionalCertificates.Import(chainPtr->Span); - chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates); - } - } - } - - if (result is not null) - { bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName); - sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certificateBuffer, certificateLength); + sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certData); } else if (_certificateRequired) { sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; } - int status = QUIC_STATUS_SUCCESS; + QUIC_TLS_ALERT_CODES result = QUIC_TLS_ALERT_CODES.SUCCESS; if (_validationCallback is not null) { wrapException = true; - if (!_validationCallback(_connection, result, chain, sslPolicyErrors)) + if (!_validationCallback(_connection, certificate, chain, sslPolicyErrors)) { wrapException = false; if (_isClient) @@ -140,7 +230,7 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* throw new AuthenticationException(SR.net_quic_cert_custom_validation); } - status = QUIC_STATUS_USER_CANCELED; + result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE; } } else if (sslPolicyErrors != SslPolicyErrors.None) @@ -150,15 +240,13 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors)); } - status = QUIC_STATUS_HANDSHAKE_FAILURE; + result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE; } - certificate = result; - return status; + return result; } catch (Exception ex) { - result?.Dispose(); if (wrapException) { throw new QuicException(QuicError.CallbackError, null, SR.net_quic_callback_error, ex); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index db3adf776d542c..5a4f626e2f5465 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -571,15 +571,20 @@ private unsafe int HandleEventPeerStreamStarted(ref PEER_STREAM_STARTED_DATA dat } private unsafe int HandleEventPeerCertificateReceived(ref PEER_CERTIFICATE_RECEIVED_DATA data) { - try + // + // The certificate validation is an expensive operation and we don't want to delay MsQuic + // worker thread. So we offload the validation to the .NET threadpool. Incidentally, this + // also prevents potential user RemoteCertificateValidationCallback from blocking MsQuic + // worker threads. + // + + var task = _sslConnectionOptions.StartAsyncCertificateValidation((IntPtr)data.Certificate, (IntPtr)data.Chain); + if (task.IsCompletedSuccessfully) { - return _sslConnectionOptions.ValidateCertificate((QUIC_BUFFER*)data.Certificate, (QUIC_BUFFER*)data.Chain, out _remoteCertificate); - } - catch (Exception ex) - { - _connectedTcs.TrySetException(ex); - return QUIC_STATUS_HANDSHAKE_FAILURE; + return task.Result ? QUIC_STATUS_SUCCESS : QUIC_STATUS_BAD_CERTIFICATE; } + + return QUIC_STATUS_PENDING; } private unsafe int HandleConnectionEvent(ref QUIC_CONNECTION_EVENT connectionEvent) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 6f0a0d8bb5b75f..88ea309054a7db 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -209,6 +209,11 @@ public async ValueTask AcceptConnectionAsync(CancellationToken c /// The TLS ClientHello data. private async void StartConnectionHandshake(QuicConnection connection, SslClientHelloInfo clientHello) { + // Yield to the threadpool immediately. This makes sure the connection options callback + // provided by the user is not invoked from the MsQuic thread and cannot delay acks + // or other operations on other connections. + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + bool wrapException = false; CancellationToken cancellationToken = default;